Skip to main content

indicators/regime/
primitives.rs

//! Technical Indicators for Regime Detection
//!
//! Self-contained indicator implementations used by the regime detection system.
//! Provides EMA, ATR, ADX, RSI, and Bollinger Bands calculations tailored for
//! market regime classification.
//!
//! These are intentionally kept within the regime crate rather than depending on
//! `indicators`, because:
//! 1. The regime crate needs specific indicator semantics (e.g., ADX with DI crossover)
//! 2. Keeps the crate self-contained with zero internal dependencies
//! 3. `indicators` can later delegate to these if desired
12
13use std::collections::{HashMap, VecDeque};
14
15use super::types::TrendDirection;
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22// ── Indicator wrappers ────────────────────────────────────────────────────────
23
24/// Batch `Indicator` wrapping the regime-internal [`ADX`] primitive.
25///
26/// Outputs `adx`, `di_plus`, and `di_minus` per bar.
27#[derive(Debug, Clone)]
28pub struct AdxIndicator {
29    pub period: usize,
30}
31
32impl AdxIndicator {
33    pub fn new(period: usize) -> Self {
34        Self { period }
35    }
36}
37
38impl Indicator for AdxIndicator {
39    fn name(&self) -> &'static str {
40        "ADX"
41    }
42    fn required_len(&self) -> usize {
43        self.period * 2
44    }
45    fn required_columns(&self) -> &[&'static str] {
46        &["high", "low", "close"]
47    }
48
49    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
50        self.check_len(candles)?;
51        let mut adx_calc = ADX::new(self.period);
52        let n = candles.len();
53        let mut adx_out = vec![f64::NAN; n];
54        let mut dip_out = vec![f64::NAN; n];
55        let mut dmi_out = vec![f64::NAN; n];
56        for (i, c) in candles.iter().enumerate() {
57            if let Some(v) = adx_calc.update(c.high, c.low, c.close) {
58                adx_out[i] = v;
59                dip_out[i] = adx_calc.di_plus().unwrap_or(f64::NAN);
60                dmi_out[i] = adx_calc.di_minus().unwrap_or(f64::NAN);
61            }
62        }
63        Ok(IndicatorOutput::from_pairs([
64            ("adx", adx_out),
65            ("di_plus", dip_out),
66            ("di_minus", dmi_out),
67        ]))
68    }
69}
70
71/// Batch `Indicator` wrapping the regime-internal [`ATR`] primitive.
72#[derive(Debug, Clone)]
73pub struct AtrPrimIndicator {
74    pub period: usize,
75}
76
77impl AtrPrimIndicator {
78    pub fn new(period: usize) -> Self {
79        Self { period }
80    }
81}
82
83impl Indicator for AtrPrimIndicator {
84    fn name(&self) -> &'static str {
85        "AtrPrim"
86    }
87    fn required_len(&self) -> usize {
88        self.period + 1
89    }
90    fn required_columns(&self) -> &[&'static str] {
91        &["high", "low", "close"]
92    }
93
94    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
95        self.check_len(candles)?;
96        let mut atr_calc = ATR::new(self.period);
97        let n = candles.len();
98        let mut out = vec![f64::NAN; n];
99        for (i, c) in candles.iter().enumerate() {
100            if let Some(v) = atr_calc.update(c.high, c.low, c.close) {
101                out[i] = v;
102            }
103        }
104        Ok(IndicatorOutput::from_pairs([("atr_prim", out)]))
105    }
106}
107
108/// Batch `Indicator` wrapping the regime-internal [`EMA`] primitive.
109#[derive(Debug, Clone)]
110pub struct EmaPrimIndicator {
111    pub period: usize,
112}
113
114impl EmaPrimIndicator {
115    pub fn new(period: usize) -> Self {
116        Self { period }
117    }
118}
119
120impl Indicator for EmaPrimIndicator {
121    fn name(&self) -> &'static str {
122        "EmaPrim"
123    }
124    fn required_len(&self) -> usize {
125        self.period
126    }
127    fn required_columns(&self) -> &[&'static str] {
128        &["close"]
129    }
130
131    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
132        self.check_len(candles)?;
133        let mut ema_calc = EMA::new(self.period);
134        let n = candles.len();
135        let mut out = vec![f64::NAN; n];
136        for (i, c) in candles.iter().enumerate() {
137            if let Some(v) = ema_calc.update(c.close) {
138                out[i] = v;
139            }
140        }
141        Ok(IndicatorOutput::from_pairs([("ema_prim", out)]))
142    }
143}
144
145/// Batch `Indicator` wrapping the regime-internal [`RSI`] primitive.
146#[derive(Debug, Clone)]
147pub struct RsiPrimIndicator {
148    pub period: usize,
149}
150
151impl RsiPrimIndicator {
152    pub fn new(period: usize) -> Self {
153        Self { period }
154    }
155}
156
157impl Indicator for RsiPrimIndicator {
158    fn name(&self) -> &'static str {
159        "RsiPrim"
160    }
161    fn required_len(&self) -> usize {
162        self.period + 1
163    }
164    fn required_columns(&self) -> &[&'static str] {
165        &["close"]
166    }
167
168    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
169        self.check_len(candles)?;
170        let mut rsi_calc = RSI::new(self.period);
171        let n = candles.len();
172        let mut out = vec![f64::NAN; n];
173        for (i, c) in candles.iter().enumerate() {
174            if let Some(v) = rsi_calc.update(c.close) {
175                out[i] = v;
176            }
177        }
178        Ok(IndicatorOutput::from_pairs([("rsi_prim", out)]))
179    }
180}
181
182/// Batch `Indicator` wrapping the regime-internal [`BollingerBands`] primitive.
183///
184/// Outputs `bb_upper`, `bb_mid`, `bb_lower`, and `bb_width` per bar.
185#[derive(Debug, Clone)]
186pub struct BbPrimIndicator {
187    pub period: usize,
188    pub std_dev: f64,
189}
190
191impl BbPrimIndicator {
192    pub fn new(period: usize, std_dev: f64) -> Self {
193        Self { period, std_dev }
194    }
195}
196
197impl Indicator for BbPrimIndicator {
198    fn name(&self) -> &'static str {
199        "BbPrim"
200    }
201    fn required_len(&self) -> usize {
202        self.period
203    }
204    fn required_columns(&self) -> &[&'static str] {
205        &["close"]
206    }
207
208    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
209        self.check_len(candles)?;
210        let mut bb = BollingerBands::new(self.period, self.std_dev);
211        let n = candles.len();
212        let mut upper = vec![f64::NAN; n];
213        let mut mid = vec![f64::NAN; n];
214        let mut lower = vec![f64::NAN; n];
215        let mut width = vec![f64::NAN; n];
216        for (i, c) in candles.iter().enumerate() {
217            if let Some(v) = bb.update(c.close) {
218                upper[i] = v.upper;
219                mid[i] = v.middle;
220                lower[i] = v.lower;
221                width[i] = v.width;
222            }
223        }
224        Ok(IndicatorOutput::from_pairs([
225            ("bb_upper", upper),
226            ("bb_mid", mid),
227            ("bb_lower", lower),
228            ("bb_width", width),
229        ]))
230    }
231}
232
233// ── Registry factory ──────────────────────────────────────────────────────────
234
235/// Default factory registers as `"primitives"` → produces [`AdxIndicator`].
236/// Use the individual wrapper structs directly for EMA, ATR, RSI, or BB.
237pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
238    let period = param_usize(params, "period", 14)?;
239    Ok(Box::new(AdxIndicator::new(period)))
240}
241
// ============================================================================
// Exponential Moving Average (EMA)
// ============================================================================

/// Exponential Moving Average calculator.
///
/// Uses the standard recurrence `EMA_t = price * k + EMA_{t-1} * (1 - k)` with
/// `k = 2 / (period + 1)`. The first price seeds the average directly (there is
/// no SMA seed), and values are withheld until `period` prices have been seen.
#[derive(Debug, Clone)]
pub struct EMA {
    period: usize,
    multiplier: f64,
    current_value: Option<f64>,
    initialized: bool,
    warmup_count: usize,
}

impl EMA {
    /// Create a new EMA with the given period.
    ///
    /// `period` is expected to be at least 1; with `period == 0` the smoothing
    /// factor would be `2.0`, which does not describe an average.
    pub fn new(period: usize) -> Self {
        let multiplier = 2.0 / (period as f64 + 1.0);
        Self {
            period,
            multiplier,
            current_value: None,
            initialized: false,
            warmup_count: 0,
        }
    }

    /// Feed one price and return the EMA once `period` prices have been observed.
    ///
    /// Returns `None` for the first `period - 1` calls. A period-1 EMA is ready
    /// (and equal to the price) on its very first call.
    pub fn update(&mut self, price: f64) -> Option<f64> {
        self.warmup_count += 1;

        let next = match self.current_value {
            Some(prev_ema) => (price - prev_ema) * self.multiplier + prev_ema,
            // The first observation seeds the average.
            None => price,
        };
        self.current_value = Some(next);

        // Checked on every call, including the seeding one. Previously this only
        // ran on the non-seeding path, so a period-1 EMA stayed "not ready" after
        // its first price, contradicting `period` as the warm-up length.
        if self.warmup_count >= self.period {
            self.initialized = true;
        }

        self.value()
    }

    /// Get the current EMA value (`None` if not yet warmed up).
    pub fn value(&self) -> Option<f64> {
        if self.initialized {
            self.current_value
        } else {
            None
        }
    }

    /// Check if the EMA has seen enough data to produce valid values.
    pub fn is_ready(&self) -> bool {
        self.initialized
    }

    /// Get the period.
    pub fn period(&self) -> usize {
        self.period
    }

    /// Reset the EMA to its freshly-constructed state (period is kept).
    pub fn reset(&mut self) {
        self.current_value = None;
        self.initialized = false;
        self.warmup_count = 0;
    }
}
323
// ============================================================================
// Average True Range (ATR)
// ============================================================================

/// Average True Range (ATR) calculator.
///
/// True range is `max(high - low, |high - prev_close|, |low - prev_close|)`,
/// falling back to `high - low` on the very first bar. The ATR is seeded with
/// the simple mean of the first `period` true ranges. After that it follows
/// Wilder's smoothing: `ATR_t = (ATR_{t-1} * (period - 1) + TR_t) / period`.
#[derive(Debug, Clone)]
pub struct ATR {
    period: usize,
    values: VecDeque<f64>,
    prev_close: Option<f64>,
    current_atr: Option<f64>,
}

impl ATR {
    /// Create a new ATR with the given period.
    pub fn new(period: usize) -> Self {
        Self {
            current_atr: None,
            prev_close: None,
            values: VecDeque::with_capacity(period),
            period,
        }
    }

    /// Feed one bar and return the ATR once `period` true ranges are available.
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        let bar_range = high - low;
        let tr = match self.prev_close {
            None => bar_range,
            Some(pc) => bar_range.max((high - pc).abs()).max((low - pc).abs()),
        };
        self.prev_close = Some(close);

        // Rolling window of the most recent `period` true ranges (used for the seed).
        self.values.push_back(tr);
        if self.values.len() > self.period {
            self.values.pop_front();
        }

        if self.values.len() < self.period {
            return self.current_atr;
        }

        let divisor = self.period as f64;
        let next = match self.current_atr {
            Some(prev) => (prev * (self.period - 1) as f64 + tr) / divisor,
            None => self.values.iter().sum::<f64>() / divisor,
        };
        self.current_atr = Some(next);
        self.current_atr
    }

    /// Get the current ATR value (`None` until warmed up).
    pub fn value(&self) -> Option<f64> {
        self.current_atr
    }

    /// `true` once the first ATR value has been produced.
    pub fn is_ready(&self) -> bool {
        self.current_atr.is_some()
    }

    /// Get the period.
    pub fn period(&self) -> usize {
        self.period
    }

    /// Forget all history; the period is kept.
    pub fn reset(&mut self) {
        self.current_atr = None;
        self.prev_close = None;
        self.values.clear();
    }
}
407
408// ============================================================================
409// Average Directional Index (ADX)
410// ============================================================================
411
412/// Average Directional Index (ADX) calculator
413///
414/// Measures trend strength (not direction). Values above 25 typically indicate
415/// a strong trend, while values below 20 suggest a ranging market.
416///
417/// Also provides +DI and -DI for trend direction via `trend_direction()`.
418#[derive(Debug, Clone)]
419pub struct ADX {
420    period: usize,
421    atr: ATR,
422    plus_dm_ema: EMA,
423    minus_dm_ema: EMA,
424    dx_values: VecDeque<f64>,
425    prev_high: Option<f64>,
426    prev_low: Option<f64>,
427    current_adx: Option<f64>,
428    plus_dir_index: Option<f64>,
429    minus_dir_index: Option<f64>,
430}
431
432impl ADX {
433    /// Create a new ADX with the given period
434    pub fn new(period: usize) -> Self {
435        Self {
436            period,
437            atr: ATR::new(period),
438            plus_dm_ema: EMA::new(period),
439            minus_dm_ema: EMA::new(period),
440            dx_values: VecDeque::with_capacity(period),
441            prev_high: None,
442            prev_low: None,
443            current_adx: None,
444            plus_dir_index: None,
445            minus_dir_index: None,
446        }
447    }
448
449    /// Update with HLC data, returning the ADX value if warmed up
450    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
451        // Calculate directional movement
452        let (plus_dm, minus_dm) = match (self.prev_high, self.prev_low) {
453            (Some(prev_h), Some(prev_l)) => {
454                let up_move = high - prev_h;
455                let down_move = prev_l - low;
456
457                let plus = if up_move > down_move && up_move > 0.0 {
458                    up_move
459                } else {
460                    0.0
461                };
462
463                let minus = if down_move > up_move && down_move > 0.0 {
464                    down_move
465                } else {
466                    0.0
467                };
468
469                (plus, minus)
470            }
471            _ => (0.0, 0.0),
472        };
473
474        self.prev_high = Some(high);
475        self.prev_low = Some(low);
476
477        // Update ATR
478        let atr = self.atr.update(high, low, close);
479
480        // Smooth directional movement
481        let smoothed_plus_dm = self.plus_dm_ema.update(plus_dm);
482        let smoothed_minus_dm = self.minus_dm_ema.update(minus_dm);
483
484        // Calculate DI values
485        if let (Some(atr_val), Some(plus_dm_smooth), Some(minus_dm_smooth)) =
486            (atr, smoothed_plus_dm, smoothed_minus_dm)
487            && atr_val > 0.0
488        {
489            let plus_dir_index = (plus_dm_smooth / atr_val) * 100.0;
490            let minus_dir_index = (minus_dm_smooth / atr_val) * 100.0;
491            self.plus_dir_index = Some(plus_dir_index);
492            self.minus_dir_index = Some(minus_dir_index);
493
494            // Calculate DX
495            let di_sum = plus_dir_index + minus_dir_index;
496            if di_sum > 0.0 {
497                let di_diff = (plus_dir_index - minus_dir_index).abs();
498                let dx = (di_diff / di_sum) * 100.0;
499
500                self.dx_values.push_back(dx);
501                if self.dx_values.len() > self.period {
502                    self.dx_values.pop_front();
503                }
504
505                // Calculate ADX as smoothed DX
506                if self.dx_values.len() >= self.period {
507                    if let Some(prev_adx) = self.current_adx {
508                        let new_adx =
509                            (prev_adx * (self.period - 1) as f64 + dx) / self.period as f64;
510                        self.current_adx = Some(new_adx);
511                    } else {
512                        let sum: f64 = self.dx_values.iter().sum();
513                        self.current_adx = Some(sum / self.period as f64);
514                    }
515                }
516            }
517        }
518
519        self.current_adx
520    }
521
522    /// Get the current ADX value
523    pub fn value(&self) -> Option<f64> {
524        self.current_adx
525    }
526
527    /// Get the +DI value
528    pub fn plus_dir_index(&self) -> Option<f64> {
529        self.plus_dir_index
530    }
531
532    /// Get the -DI value
533    pub fn minus_dir_index(&self) -> Option<f64> {
534        self.minus_dir_index
535    }
536
537    /// Returns trend direction based on DI crossover.
538    ///
539    /// - `+DI > -DI` → Bullish
540    /// - `-DI > +DI` → Bearish
541    pub fn trend_direction(&self) -> Option<TrendDirection> {
542        match (self.plus_dir_index, self.minus_dir_index) {
543            (Some(plus), Some(minus)) => {
544                if plus > minus {
545                    Some(TrendDirection::Bullish)
546                } else {
547                    Some(TrendDirection::Bearish)
548                }
549            }
550            _ => None,
551        }
552    }
553
554    /// Check if the ADX has enough data
555    pub fn is_ready(&self) -> bool {
556        self.current_adx.is_some()
557    }
558
559    /// Get the period
560    pub fn period(&self) -> usize {
561        self.period
562    }
563
564    /// Current DI+ value (directional index plus), available after warm-up.
565    pub fn di_plus(&self) -> Option<f64> {
566        self.plus_dir_index
567    }
568
569    /// Current DI- value (directional index minus), available after warm-up.
570    pub fn di_minus(&self) -> Option<f64> {
571        self.minus_dir_index
572    }
573
574    /// Reset the ADX state
575    pub fn reset(&mut self) {
576        self.atr.reset();
577        self.plus_dm_ema.reset();
578        self.minus_dm_ema.reset();
579        self.dx_values.clear();
580        self.prev_high = None;
581        self.prev_low = None;
582        self.current_adx = None;
583        self.plus_dir_index = None;
584        self.minus_dir_index = None;
585    }
586}
587
// ============================================================================
// Bollinger Bands
// ============================================================================

/// Snapshot of Bollinger Band values for a single bar.
#[derive(Debug, Clone, Copy)]
pub struct BollingerBandsValues {
    /// Upper band (SMA + n * σ)
    pub upper: f64,
    /// Middle band (SMA)
    pub middle: f64,
    /// Lower band (SMA - n * σ)
    pub lower: f64,
    /// Band width as a percentage of the middle band
    pub width: f64,
    /// Where the current width ranks among recent widths (0–100 percentile)
    pub width_percentile: f64,
    /// Where price sits within the bands (0.0 = lower, 1.0 = upper)
    pub percent_b: f64,
    /// Standard deviation of prices in the window
    pub std_dev: f64,
}

impl BollingerBandsValues {
    /// `true` when %B is at least 0.95 (price at or beyond the upper band).
    pub fn is_overbought(&self) -> bool {
        0.95 <= self.percent_b
    }

    /// `true` when %B is at most 0.05 (price at or beyond the lower band).
    pub fn is_oversold(&self) -> bool {
        0.05 >= self.percent_b
    }

    /// `true` when the width percentile reaches `threshold_percentile` (wide bands).
    pub fn is_high_volatility(&self, threshold_percentile: f64) -> bool {
        threshold_percentile <= self.width_percentile
    }

    /// `true` when the width percentile is at or below `threshold_percentile`
    /// (narrow bands, a possible precursor to a breakout).
    pub fn is_squeeze(&self, threshold_percentile: f64) -> bool {
        threshold_percentile >= self.width_percentile
    }
}
632
633/// Bollinger Bands calculator
634///
635/// Computes upper, lower, and middle bands along with band width percentile
636/// for volatility regime classification.
637#[derive(Debug, Clone)]
638pub struct BollingerBands {
639    period: usize,
640    std_dev_multiplier: f64,
641    prices: VecDeque<f64>,
642    width_history: VecDeque<f64>,
643    width_history_size: usize,
644}
645
646impl BollingerBands {
647    /// Create a new Bollinger Bands calculator
648    ///
649    /// # Arguments
650    /// * `period` - Lookback period for the SMA (typically 20)
651    /// * `std_dev_multiplier` - Standard deviation multiplier (typically 2.0)
652    pub fn new(period: usize, std_dev_multiplier: f64) -> Self {
653        Self {
654            period,
655            std_dev_multiplier,
656            prices: VecDeque::with_capacity(period),
657            width_history: VecDeque::with_capacity(100),
658            width_history_size: 100, // Keep 100 periods for percentile calc
659        }
660    }
661
662    /// Update with a new price, returning band values if warmed up
663    pub fn update(&mut self, price: f64) -> Option<BollingerBandsValues> {
664        self.prices.push_back(price);
665        if self.prices.len() > self.period {
666            self.prices.pop_front();
667        }
668
669        if self.prices.len() < self.period {
670            return None;
671        }
672
673        // Calculate SMA (middle band)
674        let sum: f64 = self.prices.iter().sum();
675        let sma = sum / self.period as f64;
676
677        // Calculate standard deviation
678        let variance: f64 =
679            self.prices.iter().map(|p| (p - sma).powi(2)).sum::<f64>() / self.period as f64;
680        let std_dev = variance.sqrt();
681
682        // Calculate bands
683        let upper = sma + (std_dev * self.std_dev_multiplier);
684        let lower = sma - (std_dev * self.std_dev_multiplier);
685        let width = if sma > 0.0 {
686            (upper - lower) / sma * 100.0 // Width as percentage of price
687        } else {
688            0.0
689        };
690
691        // Update width history for percentile calculation
692        self.width_history.push_back(width);
693        if self.width_history.len() > self.width_history_size {
694            self.width_history.pop_front();
695        }
696
697        // Calculate width percentile
698        let width_percentile = self.calculate_width_percentile(width);
699
700        // Calculate %B (where price is within bands)
701        let percent_b = if upper - lower > 0.0 {
702            (price - lower) / (upper - lower)
703        } else {
704            0.5
705        };
706
707        Some(BollingerBandsValues {
708            upper,
709            middle: sma,
710            lower,
711            width,
712            width_percentile,
713            percent_b,
714            std_dev,
715        })
716    }
717
718    /// Calculate where the current width ranks in recent history
719    fn calculate_width_percentile(&self, current_width: f64) -> f64 {
720        if self.width_history.len() < 10 {
721            return 50.0; // Not enough data
722        }
723
724        let count_below = self
725            .width_history
726            .iter()
727            .filter(|&&w| w < current_width)
728            .count();
729
730        (count_below as f64 / self.width_history.len() as f64) * 100.0
731    }
732
733    /// Check if the Bollinger Bands have enough data
734    pub fn is_ready(&self) -> bool {
735        self.prices.len() >= self.period
736    }
737
738    /// Get the period
739    pub fn period(&self) -> usize {
740        self.period
741    }
742
743    /// Get the standard deviation multiplier
744    pub fn std_dev_multiplier(&self) -> f64 {
745        self.std_dev_multiplier
746    }
747
748    /// Reset the Bollinger Bands state
749    pub fn reset(&mut self) {
750        self.prices.clear();
751        self.width_history.clear();
752    }
753}
754
755// ============================================================================
756// RSI (Relative Strength Index)
757// ============================================================================
758
759/// Relative Strength Index (RSI) calculator
760///
761/// Uses EMA-smoothed gains and losses for a responsive RSI calculation.
762/// Values above 70 indicate overbought, below 30 indicate oversold.
763#[derive(Debug, Clone)]
764pub struct RSI {
765    period: usize,
766    gains: EMA,
767    losses: EMA,
768    prev_close: Option<f64>,
769    last_rsi: Option<f64>,
770}
771
772impl RSI {
773    /// Create a new RSI with the given period (typically 14)
774    pub fn new(period: usize) -> Self {
775        Self {
776            period,
777            gains: EMA::new(period),
778            losses: EMA::new(period),
779            prev_close: None,
780            last_rsi: None,
781        }
782    }
783
784    /// Update with a new close price, returning the RSI if warmed up
785    pub fn update(&mut self, close: f64) -> Option<f64> {
786        if let Some(prev) = self.prev_close {
787            let change = close - prev;
788            let gain = if change > 0.0 { change } else { 0.0 };
789            let loss = if change < 0.0 { -change } else { 0.0 };
790
791            if let (Some(avg_gain), Some(avg_loss)) =
792                (self.gains.update(gain), self.losses.update(loss))
793            {
794                self.prev_close = Some(close);
795
796                let rsi = if avg_loss == 0.0 {
797                    100.0
798                } else {
799                    let rs = avg_gain / avg_loss;
800                    100.0 - (100.0 / (1.0 + rs))
801                };
802                self.last_rsi = Some(rsi);
803                return self.last_rsi;
804            }
805        }
806
807        self.prev_close = Some(close);
808        None
809    }
810
811    /// Get the most recent RSI value without consuming a new price tick.
812    ///
813    /// Returns `None` until the indicator has completed its warm-up period.
814    pub fn value(&self) -> Option<f64> {
815        self.last_rsi
816    }
817
818    /// Check if RSI has enough data
819    pub fn is_ready(&self) -> bool {
820        self.gains.is_ready() && self.losses.is_ready()
821    }
822
823    /// Get the period
824    pub fn period(&self) -> usize {
825        self.period
826    }
827
828    /// Reset the RSI state
829    pub fn reset(&mut self) {
830        self.gains.reset();
831        self.losses.reset();
832        self.prev_close = None;
833        self.last_rsi = None;
834    }
835}
836
// ============================================================================
// Helper Functions
// ============================================================================

/// Arithmetic mean of `prices`.
///
/// An empty slice yields `0.0` (not `NaN`).
pub fn calculate_sma(prices: &[f64]) -> f64 {
    match prices.len() {
        0 => 0.0,
        count => prices.iter().sum::<f64>() / count as f64,
    }
}
848
849// ============================================================================
850// Tests
851// ============================================================================
852
853#[cfg(test)]
854mod tests {
855    use super::*;
856
857    // --- EMA Tests ---
858
859    #[test]
860    fn test_ema_creation() {
861        let ema = EMA::new(10);
862        assert_eq!(ema.period(), 10);
863        assert!(!ema.is_ready());
864        assert!(ema.value().is_none());
865    }
866
867    #[test]
868    fn test_ema_warmup() {
869        let mut ema = EMA::new(10);
870
871        // Should return None during warmup
872        for i in 1..10 {
873            let result = ema.update(i as f64 * 10.0);
874            assert!(result.is_none(), "Should be None during warmup at step {i}");
875        }
876
877        // Should return Some after warmup
878        let result = ema.update(100.0);
879        assert!(result.is_some(), "Should be ready after {0} updates", 10);
880        assert!(ema.is_ready());
881    }
882
883    #[test]
884    fn test_ema_calculation() {
885        let mut ema = EMA::new(10);
886
887        // Warm up
888        for i in 1..=10 {
889            ema.update(i as f64 * 10.0);
890        }
891
892        assert!(ema.is_ready());
893        let value = ema.value().unwrap();
894        // EMA should be between the min and max input values
895        assert!(value > 10.0 && value <= 100.0);
896    }
897
898    #[test]
899    fn test_ema_tracks_trend() {
900        let mut ema = EMA::new(5);
901
902        // Warm up with constant price
903        for _ in 0..5 {
904            ema.update(100.0);
905        }
906        let stable = ema.value().unwrap();
907
908        // Feed higher prices
909        for _ in 0..10 {
910            ema.update(110.0);
911        }
912        let after_up = ema.value().unwrap();
913
914        assert!(after_up > stable, "EMA should increase with rising prices");
915    }
916
917    #[test]
918    fn test_ema_reset() {
919        let mut ema = EMA::new(5);
920        for _ in 0..10 {
921            ema.update(100.0);
922        }
923        assert!(ema.is_ready());
924
925        ema.reset();
926        assert!(!ema.is_ready());
927        assert!(ema.value().is_none());
928    }
929
930    // --- ATR Tests ---
931
932    #[test]
933    fn test_atr_creation() {
934        let atr = ATR::new(14);
935        assert_eq!(atr.period(), 14);
936        assert!(!atr.is_ready());
937    }
938
939    #[test]
940    fn test_atr_warmup() {
941        let mut atr = ATR::new(14);
942
943        for i in 1..=14 {
944            let base = 100.0 + i as f64;
945            let result = atr.update(base + 1.0, base - 1.0, base);
946            if i < 14 {
947                assert!(result.is_none());
948            }
949        }
950
951        assert!(atr.is_ready());
952    }
953
954    #[test]
955    fn test_atr_increases_with_volatility() {
956        let mut atr = ATR::new(14);
957
958        // Low volatility warmup
959        for i in 1..=14 {
960            let base = 100.0 + i as f64 * 0.1;
961            atr.update(base + 0.5, base - 0.5, base);
962        }
963        let low_vol_atr = atr.value().unwrap();
964
965        // High volatility bars
966        for i in 0..20 {
967            let base = 100.0 + if i % 2 == 0 { 5.0 } else { -5.0 };
968            atr.update(base + 3.0, base - 3.0, base);
969        }
970        let high_vol_atr = atr.value().unwrap();
971
972        assert!(
973            high_vol_atr > low_vol_atr,
974            "ATR should increase with volatility: {high_vol_atr} vs {low_vol_atr}"
975        );
976    }
977
978    #[test]
979    fn test_atr_reset() {
980        let mut atr = ATR::new(14);
981        for i in 0..20 {
982            let base = 100.0 + i as f64;
983            atr.update(base + 1.0, base - 1.0, base);
984        }
985        assert!(atr.is_ready());
986
987        atr.reset();
988        assert!(!atr.is_ready());
989        assert!(atr.value().is_none());
990    }
991
992    // --- ADX Tests ---
993
994    #[test]
995    fn test_adx_creation() {
996        let adx = ADX::new(14);
997        assert_eq!(adx.period(), 14);
998        assert!(!adx.is_ready());
999    }
1000
1001    #[test]
1002    fn test_adx_trending_detection() {
1003        let mut adx = ADX::new(14);
1004
1005        // Simulate strong uptrend (prices going up steadily)
1006        for i in 1..=50 {
1007            let high = 100.0 + i as f64 * 2.0;
1008            let low = 100.0 + i as f64 * 2.0 - 1.0;
1009            let close = 100.0 + i as f64 * 2.0 - 0.5;
1010            adx.update(high, low, close);
1011        }
1012
1013        if let Some(adx_value) = adx.value() {
1014            assert!(
1015                adx_value > 20.0,
1016                "ADX should indicate trend in strong uptrend: {adx_value}"
1017            );
1018        }
1019    }
1020
1021    #[test]
1022    fn test_adx_trend_direction() {
1023        let mut adx = ADX::new(14);
1024
1025        // Strong uptrend
1026        for i in 1..=50 {
1027            let high = 100.0 + i as f64 * 2.0;
1028            let low = 100.0 + i as f64 * 2.0 - 1.0;
1029            let close = 100.0 + i as f64 * 2.0 - 0.5;
1030            adx.update(high, low, close);
1031        }
1032
1033        if let Some(dir) = adx.trend_direction() {
1034            assert_eq!(
1035                dir,
1036                TrendDirection::Bullish,
1037                "Should detect bullish direction in uptrend"
1038            );
1039        }
1040    }
1041
1042    #[test]
1043    fn test_adx_di_values() {
1044        let mut adx = ADX::new(14);
1045
1046        for i in 1..=50 {
1047            let high = 100.0 + i as f64 * 2.0;
1048            let low = 100.0 + i as f64 * 2.0 - 1.0;
1049            let close = 100.0 + i as f64 * 2.0 - 0.5;
1050            adx.update(high, low, close);
1051        }
1052
1053        // In an uptrend, +DI should be higher than -DI
1054        if let (Some(plus), Some(minus)) = (adx.plus_dir_index(), adx.minus_dir_index()) {
1055            assert!(
1056                plus > minus,
1057                "+DI ({plus}) should be > -DI ({minus}) in uptrend"
1058            );
1059        }
1060    }
1061
1062    #[test]
1063    fn test_adx_reset() {
1064        let mut adx = ADX::new(14);
1065        for i in 1..=50 {
1066            let base = 100.0 + i as f64;
1067            adx.update(base + 1.0, base - 1.0, base);
1068        }
1069        assert!(adx.is_ready());
1070
1071        adx.reset();
1072        assert!(!adx.is_ready());
1073        assert!(adx.value().is_none());
1074        assert!(adx.plus_dir_index().is_none());
1075        assert!(adx.minus_dir_index().is_none());
1076    }
1077
1078    // --- Bollinger Bands Tests ---
1079
1080    #[test]
1081    fn test_bb_creation() {
1082        let bb = BollingerBands::new(20, 2.0);
1083        assert_eq!(bb.period(), 20);
1084        assert_eq!(bb.std_dev_multiplier(), 2.0);
1085        assert!(!bb.is_ready());
1086    }
1087
1088    #[test]
1089    fn test_bb_warmup() {
1090        let mut bb = BollingerBands::new(20, 2.0);
1091
1092        for i in 1..20 {
1093            let result = bb.update(100.0 + i as f64 * 0.1);
1094            assert!(result.is_none());
1095        }
1096
1097        let result = bb.update(102.0);
1098        assert!(result.is_some());
1099        assert!(bb.is_ready());
1100    }
1101
1102    #[test]
1103    fn test_bb_band_ordering() {
1104        let mut bb = BollingerBands::new(20, 2.0);
1105
1106        for i in 1..=25 {
1107            let price = 100.0 + (i as f64 % 5.0);
1108            bb.update(price);
1109        }
1110
1111        let result = bb.update(102.0).unwrap();
1112        assert!(
1113            result.upper > result.middle,
1114            "Upper band ({}) should be > middle ({})",
1115            result.upper,
1116            result.middle
1117        );
1118        assert!(
1119            result.middle > result.lower,
1120            "Middle ({}) should be > lower ({})",
1121            result.middle,
1122            result.lower
1123        );
1124    }
1125
1126    #[test]
1127    fn test_bb_percent_b() {
1128        let mut bb = BollingerBands::new(20, 2.0);
1129
1130        // Build some history with variance
1131        for i in 1..=20 {
1132            bb.update(100.0 + (i as f64 % 3.0));
1133        }
1134
1135        // Price at middle should give %B near 0.5
1136        let values = bb.update(100.0 + 1.0);
1137        if let Some(v) = values {
1138            // %B should be between 0 and 1 for normal prices
1139            assert!(
1140                v.percent_b >= 0.0 && v.percent_b <= 1.0,
1141                "%B should be in [0,1]: {}",
1142                v.percent_b
1143            );
1144        }
1145    }
1146
1147    #[test]
1148    fn test_bb_squeeze_detection() {
1149        let mut bb = BollingerBands::new(20, 2.0);
1150
1151        // First, create wide bands with volatile data
1152        for i in 0..50 {
1153            let price = 100.0 + if i % 2 == 0 { 10.0 } else { -10.0 };
1154            bb.update(price);
1155        }
1156
1157        // Then tighten with constant price
1158        for _ in 0..50 {
1159            bb.update(100.0);
1160        }
1161
1162        let result = bb.update(100.0).unwrap();
1163        // After constant prices, width percentile should be low
1164        assert!(
1165            result.width_percentile < 50.0,
1166            "Constant prices should produce low width percentile: {}",
1167            result.width_percentile
1168        );
1169    }
1170
1171    #[test]
1172    fn test_bb_overbought_oversold() {
1173        let mut bb = BollingerBands::new(20, 2.0);
1174
1175        // Build history around 100
1176        for _ in 0..20 {
1177            bb.update(100.0);
1178        }
1179
1180        // Price far above should be overbought
1181        let result = bb.update(110.0).unwrap();
1182        assert!(
1183            result.is_overbought(),
1184            "Price far above bands should be overbought, %B = {}",
1185            result.percent_b
1186        );
1187    }
1188
1189    #[test]
1190    fn test_bb_reset() {
1191        let mut bb = BollingerBands::new(20, 2.0);
1192        for i in 0..25 {
1193            bb.update(100.0 + i as f64);
1194        }
1195        assert!(bb.is_ready());
1196
1197        bb.reset();
1198        assert!(!bb.is_ready());
1199    }
1200
1201    // --- RSI Tests ---
1202
1203    #[test]
1204    fn test_rsi_creation() {
1205        let rsi = RSI::new(14);
1206        assert_eq!(rsi.period(), 14);
1207        assert!(!rsi.is_ready());
1208    }
1209
1210    #[test]
1211    fn test_rsi_bullish_market() {
1212        let mut rsi = RSI::new(14);
1213
1214        // Consistently rising prices
1215        let mut last_rsi = None;
1216        for i in 0..30 {
1217            let price = 100.0 + i as f64;
1218            if let Some(val) = rsi.update(price) {
1219                last_rsi = Some(val);
1220            }
1221        }
1222
1223        if let Some(val) = last_rsi {
1224            assert!(
1225                val > 50.0,
1226                "RSI should be above 50 in bullish market: {val}"
1227            );
1228        }
1229    }
1230
1231    #[test]
1232    fn test_rsi_bearish_market() {
1233        let mut rsi = RSI::new(14);
1234
1235        // Consistently falling prices
1236        let mut last_rsi = None;
1237        for i in 0..30 {
1238            let price = 200.0 - i as f64;
1239            if let Some(val) = rsi.update(price) {
1240                last_rsi = Some(val);
1241            }
1242        }
1243
1244        if let Some(val) = last_rsi {
1245            assert!(
1246                val < 50.0,
1247                "RSI should be below 50 in bearish market: {val}"
1248            );
1249        }
1250    }
1251
1252    #[test]
1253    fn test_rsi_range() {
1254        let mut rsi = RSI::new(14);
1255
1256        for i in 0..50 {
1257            let price = 100.0 + (i as f64 * 0.7).sin() * 10.0;
1258            if let Some(val) = rsi.update(price) {
1259                assert!(
1260                    (0.0..=100.0).contains(&val),
1261                    "RSI should be in [0, 100]: {val}"
1262                );
1263            }
1264        }
1265    }
1266
1267    #[test]
1268    fn test_rsi_value_cached() {
1269        let mut rsi = RSI::new(14);
1270        assert!(
1271            rsi.value().is_none(),
1272            "value() should be None before warmup"
1273        );
1274
1275        let mut last_from_update = None;
1276        for i in 0..30 {
1277            let price = 100.0 + i as f64;
1278            if let Some(v) = rsi.update(price) {
1279                last_from_update = Some(v);
1280            }
1281        }
1282
1283        // value() must equal the last result returned by update()
1284        assert_eq!(
1285            rsi.value(),
1286            last_from_update,
1287            "value() must equal the last update() result"
1288        );
1289    }
1290
1291    #[test]
1292    fn test_rsi_reset_clears_value() {
1293        let mut rsi = RSI::new(14);
1294        for i in 0..30 {
1295            rsi.update(100.0 + i as f64);
1296        }
1297        assert!(rsi.value().is_some());
1298        rsi.reset();
1299        assert!(rsi.value().is_none(), "value() should be None after reset");
1300    }
1301
1302    // --- SMA Helper Test ---
1303
1304    #[test]
1305    fn test_calculate_sma() {
1306        assert_eq!(calculate_sma(&[1.0, 2.0, 3.0, 4.0, 5.0]), 3.0);
1307        assert_eq!(calculate_sma(&[100.0]), 100.0);
1308        assert_eq!(calculate_sma(&[]), 0.0);
1309    }
1310
1311    #[test]
1312    fn test_calculate_sma_precision() {
1313        let prices = vec![10.0, 20.0, 30.0];
1314        let sma = calculate_sma(&prices);
1315        assert!((sma - 20.0).abs() < f64::EPSILON);
1316    }
1317}