Skip to main content

indicators/regime/
primitives.rs

//! Technical Indicators for Regime Detection
//!
//! Self-contained indicator implementations used by the regime detection system.
//! Provides EMA, ATR, ADX, RSI, and Bollinger Bands calculations optimized for
//! market regime classification, plus batch `Indicator` wrappers around them.
//!
//! These are intentionally kept within the regime crate rather than depending on
//! `indicators`, because:
//! 1. The regime crate needs specific indicator semantics (e.g., ADX with DI crossover)
//! 2. It keeps the crate self-contained with zero internal dependencies
//! 3. `indicators` can later delegate to these if desired
12
13use std::collections::{HashMap, VecDeque};
14
15use super::types::TrendDirection;
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22// ── Indicator wrappers ────────────────────────────────────────────────────────
23
24/// Batch `Indicator` wrapping the regime-internal [`ADX`] primitive.
25///
26/// Outputs `adx`, `di_plus`, and `di_minus` per bar.
27#[derive(Debug, Clone)]
28pub struct AdxIndicator {
29    pub period: usize,
30}
31
32impl AdxIndicator {
33    pub fn new(period: usize) -> Self {
34        Self { period }
35    }
36}
37
38impl Indicator for AdxIndicator {
39    fn name(&self) -> &'static str {
40        "ADX"
41    }
42    fn required_len(&self) -> usize {
43        self.period * 2
44    }
45    fn required_columns(&self) -> &[&'static str] {
46        &["high", "low", "close"]
47    }
48
49    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
50        self.check_len(candles)?;
51        let mut adx_calc = ADX::new(self.period);
52        let n = candles.len();
53        let mut adx_out = vec![f64::NAN; n];
54        let mut dip_out = vec![f64::NAN; n];
55        let mut dmi_out = vec![f64::NAN; n];
56        for (i, c) in candles.iter().enumerate() {
57            if let Some(v) = adx_calc.update(c.high, c.low, c.close) {
58                adx_out[i] = v;
59                dip_out[i] = adx_calc.di_plus().unwrap_or(f64::NAN);
60                dmi_out[i] = adx_calc.di_minus().unwrap_or(f64::NAN);
61            }
62        }
63        Ok(IndicatorOutput::from_pairs([
64            ("adx", adx_out),
65            ("di_plus", dip_out),
66            ("di_minus", dmi_out),
67        ]))
68    }
69}
70
71/// Batch `Indicator` wrapping the regime-internal [`ATR`] primitive.
72#[derive(Debug, Clone)]
73pub struct AtrPrimIndicator {
74    pub period: usize,
75}
76
77impl AtrPrimIndicator {
78    pub fn new(period: usize) -> Self {
79        Self { period }
80    }
81}
82
83impl Indicator for AtrPrimIndicator {
84    fn name(&self) -> &'static str {
85        "AtrPrim"
86    }
87    fn required_len(&self) -> usize {
88        self.period + 1
89    }
90    fn required_columns(&self) -> &[&'static str] {
91        &["high", "low", "close"]
92    }
93
94    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
95        self.check_len(candles)?;
96        let mut atr_calc = ATR::new(self.period);
97        let n = candles.len();
98        let mut out = vec![f64::NAN; n];
99        for (i, c) in candles.iter().enumerate() {
100            if let Some(v) = atr_calc.update(c.high, c.low, c.close) {
101                out[i] = v;
102            }
103        }
104        Ok(IndicatorOutput::from_pairs([("atr_prim", out)]))
105    }
106}
107
108/// Batch `Indicator` wrapping the regime-internal [`EMA`] primitive.
109#[derive(Debug, Clone)]
110pub struct EmaPrimIndicator {
111    pub period: usize,
112}
113
114impl EmaPrimIndicator {
115    pub fn new(period: usize) -> Self {
116        Self { period }
117    }
118}
119
120impl Indicator for EmaPrimIndicator {
121    fn name(&self) -> &'static str {
122        "EmaPrim"
123    }
124    fn required_len(&self) -> usize {
125        self.period
126    }
127    fn required_columns(&self) -> &[&'static str] {
128        &["close"]
129    }
130
131    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
132        self.check_len(candles)?;
133        let mut ema_calc = EMA::new(self.period);
134        let n = candles.len();
135        let mut out = vec![f64::NAN; n];
136        for (i, c) in candles.iter().enumerate() {
137            if let Some(v) = ema_calc.update(c.close) {
138                out[i] = v;
139            }
140        }
141        Ok(IndicatorOutput::from_pairs([("ema_prim", out)]))
142    }
143}
144
145/// Batch `Indicator` wrapping the regime-internal [`RSI`] primitive.
146#[derive(Debug, Clone)]
147pub struct RsiPrimIndicator {
148    pub period: usize,
149}
150
151impl RsiPrimIndicator {
152    pub fn new(period: usize) -> Self {
153        Self { period }
154    }
155}
156
157impl Indicator for RsiPrimIndicator {
158    fn name(&self) -> &'static str {
159        "RsiPrim"
160    }
161    fn required_len(&self) -> usize {
162        self.period + 1
163    }
164    fn required_columns(&self) -> &[&'static str] {
165        &["close"]
166    }
167
168    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
169        self.check_len(candles)?;
170        let mut rsi_calc = RSI::new(self.period);
171        let n = candles.len();
172        let mut out = vec![f64::NAN; n];
173        for (i, c) in candles.iter().enumerate() {
174            if let Some(v) = rsi_calc.update(c.close) {
175                out[i] = v;
176            }
177        }
178        Ok(IndicatorOutput::from_pairs([("rsi_prim", out)]))
179    }
180}
181
182/// Batch `Indicator` wrapping the regime-internal [`BollingerBands`] primitive.
183///
184/// Outputs `bb_upper`, `bb_mid`, `bb_lower`, and `bb_width` per bar.
185#[derive(Debug, Clone)]
186pub struct BbPrimIndicator {
187    pub period: usize,
188    pub std_dev: f64,
189}
190
191impl BbPrimIndicator {
192    pub fn new(period: usize, std_dev: f64) -> Self {
193        Self { period, std_dev }
194    }
195}
196
197impl Indicator for BbPrimIndicator {
198    fn name(&self) -> &'static str {
199        "BbPrim"
200    }
201    fn required_len(&self) -> usize {
202        self.period
203    }
204    fn required_columns(&self) -> &[&'static str] {
205        &["close"]
206    }
207
208    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
209        self.check_len(candles)?;
210        let mut bb = BollingerBands::new(self.period, self.std_dev);
211        let n = candles.len();
212        let mut upper = vec![f64::NAN; n];
213        let mut mid = vec![f64::NAN; n];
214        let mut lower = vec![f64::NAN; n];
215        let mut width = vec![f64::NAN; n];
216        for (i, c) in candles.iter().enumerate() {
217            if let Some(v) = bb.update(c.close) {
218                upper[i] = v.upper;
219                mid[i] = v.middle;
220                lower[i] = v.lower;
221                width[i] = v.width;
222            }
223        }
224        Ok(IndicatorOutput::from_pairs([
225            ("bb_upper", upper),
226            ("bb_mid", mid),
227            ("bb_lower", lower),
228            ("bb_width", width),
229        ]))
230    }
231}
232
233// ── Registry factory ──────────────────────────────────────────────────────────
234
235/// Default factory registers as `"primitives"` → produces [`AdxIndicator`].
236/// Use the individual wrapper structs directly for EMA, ATR, RSI, or BB.
237pub fn factory<S: ::std::hash::BuildHasher>(
238    params: &HashMap<String, String, S>,
239) -> Result<Box<dyn Indicator>, IndicatorError> {
240    let period = param_usize(params, "period", 14)?;
241    Ok(Box::new(AdxIndicator::new(period)))
242}
243
244// ============================================================================
245// Exponential Moving Average (EMA)
246// ============================================================================
247
/// Exponential Moving Average (EMA) calculator.
///
/// Applies the recurrence `EMA_t = EMA_{t-1} + k * (price_t - EMA_{t-1})` with the
/// smoothing factor `k = 2 / (period + 1)`. The series is seeded with the first price
/// it sees (not with an SMA), and values are reported once `period` prices have been
/// consumed: a period-`n` EMA yields its first value on the `n`-th update.
///
/// A `period` of `0` gives `k = 2`, which is not a meaningful average; callers
/// should pass `period >= 1`.
#[derive(Debug, Clone)]
pub struct EMA {
    /// Lookback period; also the number of updates required before values are reported.
    period: usize,
    /// Smoothing factor `k = 2 / (period + 1)`.
    multiplier: f64,
    /// Running average, `Some` from the first update onward (even during warm-up).
    current_value: Option<f64>,
    /// Set once `warmup_count >= period`.
    initialized: bool,
    /// Number of prices consumed since construction or the last `reset`.
    warmup_count: usize,
}

impl EMA {
    /// Create a new EMA with the given period.
    pub fn new(period: usize) -> Self {
        let multiplier = 2.0 / (period as f64 + 1.0);
        Self {
            period,
            multiplier,
            current_value: None,
            initialized: false,
            warmup_count: 0,
        }
    }

    /// Feed one price and return the EMA once warmed up.
    ///
    /// Returns `None` for the first `period - 1` updates and `Some(ema)` from the
    /// `period`-th update onward.
    pub fn update(&mut self, price: f64) -> Option<f64> {
        self.warmup_count = self.warmup_count.saturating_add(1);

        let next = match self.current_value {
            Some(prev_ema) => (price - prev_ema) * self.multiplier + prev_ema,
            // The first observation seeds the average.
            None => price,
        };
        self.current_value = Some(next);

        // Readiness is checked on every update, including the seeding one, so that a
        // period-1 EMA is ready after exactly one price.
        if self.warmup_count >= self.period {
            self.initialized = true;
        }

        self.value()
    }

    /// Get the current EMA value (`None` if not yet warmed up).
    pub fn value(&self) -> Option<f64> {
        if self.initialized {
            self.current_value
        } else {
            None
        }
    }

    /// Check if the EMA has consumed enough prices to report values.
    pub fn is_ready(&self) -> bool {
        self.initialized
    }

    /// Get the period.
    pub fn period(&self) -> usize {
        self.period
    }

    /// Reset the EMA to its freshly constructed state (period is kept).
    pub fn reset(&mut self) {
        self.current_value = None;
        self.initialized = false;
        self.warmup_count = 0;
    }
}
325
326// ============================================================================
327// Average True Range (ATR)
328// ============================================================================
329
/// Average True Range (ATR) calculator using Wilder's smoothing.
///
/// The true range of a bar is `max(high - low, |high - prev_close|, |low - prev_close|)`;
/// on the very first bar (no previous close) it is simply `high - low`.
/// The first ATR is the arithmetic mean of the first `period` true ranges; afterwards
/// `ATR_t = (ATR_{t-1} * (period - 1) + TR_t) / period`.
#[derive(Debug, Clone)]
pub struct ATR {
    /// Smoothing period.
    period: usize,
    /// Sliding window of the most recent `period` true ranges (used for the seed mean).
    values: VecDeque<f64>,
    /// Close of the previous bar, if any.
    prev_close: Option<f64>,
    /// Latest ATR, `None` until `period` true ranges have been seen.
    current_atr: Option<f64>,
}

impl ATR {
    /// Create a new ATR with the given period.
    pub fn new(period: usize) -> Self {
        ATR {
            period,
            values: VecDeque::with_capacity(period),
            prev_close: None,
            current_atr: None,
        }
    }

    /// Feed one bar and return the ATR once `period` bars have been seen.
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        let range = self.prev_close.map_or(high - low, |prev| {
            (high - low).max((high - prev).abs()).max((low - prev).abs())
        });
        self.prev_close = Some(close);

        self.values.push_back(range);
        if self.values.len() > self.period {
            self.values.pop_front();
        }

        // Still warming up: nothing to report yet.
        if self.values.len() < self.period {
            return self.current_atr;
        }

        let n = self.period as f64;
        self.current_atr = Some(match self.current_atr {
            // Wilder's recursive smoothing.
            Some(prev) => (prev * (self.period - 1) as f64 + range) / n,
            // Seed with the simple mean of the first `period` true ranges.
            None => self.values.iter().sum::<f64>() / n,
        });
        self.current_atr
    }

    /// Get the current ATR value, if warmed up.
    pub fn value(&self) -> Option<f64> {
        self.current_atr
    }

    /// Whether an ATR value is available.
    pub fn is_ready(&self) -> bool {
        self.current_atr.is_some()
    }

    /// Get the period.
    pub fn period(&self) -> usize {
        self.period
    }

    /// Clear all state (period is kept).
    pub fn reset(&mut self) {
        self.values.clear();
        self.prev_close = None;
        self.current_atr = None;
    }
}
409
410// ============================================================================
411// Average Directional Index (ADX)
412// ============================================================================
413
414/// Average Directional Index (ADX) calculator
415///
416/// Measures trend strength (not direction). Values above 25 typically indicate
417/// a strong trend, while values below 20 suggest a ranging market.
418///
419/// Also provides +DI and -DI for trend direction via `trend_direction()`.
420#[derive(Debug, Clone)]
421pub struct ADX {
422    period: usize,
423    atr: ATR,
424    plus_dm_ema: EMA,
425    minus_dm_ema: EMA,
426    dx_values: VecDeque<f64>,
427    prev_high: Option<f64>,
428    prev_low: Option<f64>,
429    current_adx: Option<f64>,
430    plus_dir_index: Option<f64>,
431    minus_dir_index: Option<f64>,
432}
433
434impl ADX {
435    /// Create a new ADX with the given period
436    pub fn new(period: usize) -> Self {
437        Self {
438            period,
439            atr: ATR::new(period),
440            plus_dm_ema: EMA::new(period),
441            minus_dm_ema: EMA::new(period),
442            dx_values: VecDeque::with_capacity(period),
443            prev_high: None,
444            prev_low: None,
445            current_adx: None,
446            plus_dir_index: None,
447            minus_dir_index: None,
448        }
449    }
450
451    /// Update with HLC data, returning the ADX value if warmed up
452    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
453        // Calculate directional movement
454        let (plus_dm, minus_dm) = match (self.prev_high, self.prev_low) {
455            (Some(prev_h), Some(prev_l)) => {
456                let up_move = high - prev_h;
457                let down_move = prev_l - low;
458
459                let plus = if up_move > down_move && up_move > 0.0 {
460                    up_move
461                } else {
462                    0.0
463                };
464
465                let minus = if down_move > up_move && down_move > 0.0 {
466                    down_move
467                } else {
468                    0.0
469                };
470
471                (plus, minus)
472            }
473            _ => (0.0, 0.0),
474        };
475
476        self.prev_high = Some(high);
477        self.prev_low = Some(low);
478
479        // Update ATR
480        let atr = self.atr.update(high, low, close);
481
482        // Smooth directional movement
483        let smoothed_plus_dm = self.plus_dm_ema.update(plus_dm);
484        let smoothed_minus_dm = self.minus_dm_ema.update(minus_dm);
485
486        // Calculate DI values
487        if let (Some(atr_val), Some(plus_dm_smooth), Some(minus_dm_smooth)) =
488            (atr, smoothed_plus_dm, smoothed_minus_dm)
489            && atr_val > 0.0
490        {
491            let plus_dir_index = (plus_dm_smooth / atr_val) * 100.0;
492            let minus_dir_index = (minus_dm_smooth / atr_val) * 100.0;
493            self.plus_dir_index = Some(plus_dir_index);
494            self.minus_dir_index = Some(minus_dir_index);
495
496            // Calculate DX
497            let di_sum = plus_dir_index + minus_dir_index;
498            if di_sum > 0.0 {
499                let di_diff = (plus_dir_index - minus_dir_index).abs();
500                let dx = (di_diff / di_sum) * 100.0;
501
502                self.dx_values.push_back(dx);
503                if self.dx_values.len() > self.period {
504                    self.dx_values.pop_front();
505                }
506
507                // Calculate ADX as smoothed DX
508                if self.dx_values.len() >= self.period {
509                    if let Some(prev_adx) = self.current_adx {
510                        let new_adx =
511                            (prev_adx * (self.period - 1) as f64 + dx) / self.period as f64;
512                        self.current_adx = Some(new_adx);
513                    } else {
514                        let sum: f64 = self.dx_values.iter().sum();
515                        self.current_adx = Some(sum / self.period as f64);
516                    }
517                }
518            }
519        }
520
521        self.current_adx
522    }
523
524    /// Get the current ADX value
525    pub fn value(&self) -> Option<f64> {
526        self.current_adx
527    }
528
529    /// Get the +DI value
530    pub fn plus_dir_index(&self) -> Option<f64> {
531        self.plus_dir_index
532    }
533
534    /// Get the -DI value
535    pub fn minus_dir_index(&self) -> Option<f64> {
536        self.minus_dir_index
537    }
538
539    /// Returns trend direction based on DI crossover.
540    ///
541    /// - `+DI > -DI` → Bullish
542    /// - `-DI > +DI` → Bearish
543    pub fn trend_direction(&self) -> Option<TrendDirection> {
544        match (self.plus_dir_index, self.minus_dir_index) {
545            (Some(plus), Some(minus)) => {
546                if plus > minus {
547                    Some(TrendDirection::Bullish)
548                } else {
549                    Some(TrendDirection::Bearish)
550                }
551            }
552            _ => None,
553        }
554    }
555
556    /// Check if the ADX has enough data
557    pub fn is_ready(&self) -> bool {
558        self.current_adx.is_some()
559    }
560
561    /// Get the period
562    pub fn period(&self) -> usize {
563        self.period
564    }
565
566    /// Current DI+ value (directional index plus), available after warm-up.
567    pub fn di_plus(&self) -> Option<f64> {
568        self.plus_dir_index
569    }
570
571    /// Current DI- value (directional index minus), available after warm-up.
572    pub fn di_minus(&self) -> Option<f64> {
573        self.minus_dir_index
574    }
575
576    /// Reset the ADX state
577    pub fn reset(&mut self) {
578        self.atr.reset();
579        self.plus_dm_ema.reset();
580        self.minus_dm_ema.reset();
581        self.dx_values.clear();
582        self.prev_high = None;
583        self.prev_low = None;
584        self.current_adx = None;
585        self.plus_dir_index = None;
586        self.minus_dir_index = None;
587    }
588}
589
590// ============================================================================
591// Bollinger Bands
592// ============================================================================
593
/// One snapshot of Bollinger Band readings for a single bar.
#[derive(Debug, Clone, Copy)]
pub struct BollingerBandsValues {
    /// Upper band: `SMA + multiplier * σ`.
    pub upper: f64,
    /// Middle band: the SMA.
    pub middle: f64,
    /// Lower band: `SMA - multiplier * σ`.
    pub lower: f64,
    /// Band width `(upper - lower)` as a percentage of the middle band.
    pub width: f64,
    /// Percentile rank (0–100) of the current width within recent width history.
    pub width_percentile: f64,
    /// Position of price within the bands: 0.0 at the lower band, 1.0 at the upper band;
    /// may fall outside `[0, 1]` when price is outside the bands.
    pub percent_b: f64,
    /// Standard deviation of the prices in the window.
    pub std_dev: f64,
}

impl BollingerBandsValues {
    /// `true` when price sits at or above 95 % of the band range (`%B >= 0.95`).
    pub fn is_overbought(&self) -> bool {
        0.95 <= self.percent_b
    }

    /// `true` when price sits at or below 5 % of the band range (`%B <= 0.05`).
    pub fn is_oversold(&self) -> bool {
        0.05 >= self.percent_b
    }

    /// `true` when the width percentile is at or above `threshold_percentile`.
    pub fn is_high_volatility(&self, threshold_percentile: f64) -> bool {
        threshold_percentile <= self.width_percentile
    }

    /// `true` when the width percentile is at or below `threshold_percentile`
    /// (a "squeeze", often read as a precursor to a breakout).
    pub fn is_squeeze(&self, threshold_percentile: f64) -> bool {
        threshold_percentile >= self.width_percentile
    }
}
634
635/// Bollinger Bands calculator
636///
637/// Computes upper, lower, and middle bands along with band width percentile
638/// for volatility regime classification.
639#[derive(Debug, Clone)]
640pub struct BollingerBands {
641    period: usize,
642    std_dev_multiplier: f64,
643    prices: VecDeque<f64>,
644    width_history: VecDeque<f64>,
645    width_history_size: usize,
646}
647
648impl BollingerBands {
649    /// Create a new Bollinger Bands calculator
650    ///
651    /// # Arguments
652    /// * `period` - Lookback period for the SMA (typically 20)
653    /// * `std_dev_multiplier` - Standard deviation multiplier (typically 2.0)
654    pub fn new(period: usize, std_dev_multiplier: f64) -> Self {
655        Self {
656            period,
657            std_dev_multiplier,
658            prices: VecDeque::with_capacity(period),
659            width_history: VecDeque::with_capacity(100),
660            width_history_size: 100, // Keep 100 periods for percentile calc
661        }
662    }
663
664    /// Update with a new price, returning band values if warmed up
665    pub fn update(&mut self, price: f64) -> Option<BollingerBandsValues> {
666        self.prices.push_back(price);
667        if self.prices.len() > self.period {
668            self.prices.pop_front();
669        }
670
671        if self.prices.len() < self.period {
672            return None;
673        }
674
675        // Calculate SMA (middle band)
676        let sum: f64 = self.prices.iter().sum();
677        let sma = sum / self.period as f64;
678
679        // Calculate standard deviation
680        let variance: f64 =
681            self.prices.iter().map(|p| (p - sma).powi(2)).sum::<f64>() / self.period as f64;
682        let std_dev = variance.sqrt();
683
684        // Calculate bands
685        let upper = sma + (std_dev * self.std_dev_multiplier);
686        let lower = sma - (std_dev * self.std_dev_multiplier);
687        let width = if sma > 0.0 {
688            (upper - lower) / sma * 100.0 // Width as percentage of price
689        } else {
690            0.0
691        };
692
693        // Update width history for percentile calculation
694        self.width_history.push_back(width);
695        if self.width_history.len() > self.width_history_size {
696            self.width_history.pop_front();
697        }
698
699        // Calculate width percentile
700        let width_percentile = self.calculate_width_percentile(width);
701
702        // Calculate %B (where price is within bands)
703        let percent_b = if upper - lower > 0.0 {
704            (price - lower) / (upper - lower)
705        } else {
706            0.5
707        };
708
709        Some(BollingerBandsValues {
710            upper,
711            middle: sma,
712            lower,
713            width,
714            width_percentile,
715            percent_b,
716            std_dev,
717        })
718    }
719
720    /// Calculate where the current width ranks in recent history
721    fn calculate_width_percentile(&self, current_width: f64) -> f64 {
722        if self.width_history.len() < 10 {
723            return 50.0; // Not enough data
724        }
725
726        let count_below = self
727            .width_history
728            .iter()
729            .filter(|&&w| w < current_width)
730            .count();
731
732        (count_below as f64 / self.width_history.len() as f64) * 100.0
733    }
734
735    /// Check if the Bollinger Bands have enough data
736    pub fn is_ready(&self) -> bool {
737        self.prices.len() >= self.period
738    }
739
740    /// Get the period
741    pub fn period(&self) -> usize {
742        self.period
743    }
744
745    /// Get the standard deviation multiplier
746    pub fn std_dev_multiplier(&self) -> f64 {
747        self.std_dev_multiplier
748    }
749
750    /// Reset the Bollinger Bands state
751    pub fn reset(&mut self) {
752        self.prices.clear();
753        self.width_history.clear();
754    }
755}
756
757// ============================================================================
758// RSI (Relative Strength Index)
759// ============================================================================
760
761/// Relative Strength Index (RSI) calculator
762///
763/// Uses EMA-smoothed gains and losses for a responsive RSI calculation.
764/// Values above 70 indicate overbought, below 30 indicate oversold.
765#[derive(Debug, Clone)]
766pub struct RSI {
767    period: usize,
768    gains: EMA,
769    losses: EMA,
770    prev_close: Option<f64>,
771    last_rsi: Option<f64>,
772}
773
774impl RSI {
775    /// Create a new RSI with the given period (typically 14)
776    pub fn new(period: usize) -> Self {
777        Self {
778            period,
779            gains: EMA::new(period),
780            losses: EMA::new(period),
781            prev_close: None,
782            last_rsi: None,
783        }
784    }
785
786    /// Update with a new close price, returning the RSI if warmed up
787    pub fn update(&mut self, close: f64) -> Option<f64> {
788        if let Some(prev) = self.prev_close {
789            let change = close - prev;
790            let gain = if change > 0.0 { change } else { 0.0 };
791            let loss = if change < 0.0 { -change } else { 0.0 };
792
793            if let (Some(avg_gain), Some(avg_loss)) =
794                (self.gains.update(gain), self.losses.update(loss))
795            {
796                self.prev_close = Some(close);
797
798                let rsi = if avg_loss == 0.0 {
799                    100.0
800                } else {
801                    let rs = avg_gain / avg_loss;
802                    100.0 - (100.0 / (1.0 + rs))
803                };
804                self.last_rsi = Some(rsi);
805                return self.last_rsi;
806            }
807        }
808
809        self.prev_close = Some(close);
810        None
811    }
812
813    /// Get the most recent RSI value without consuming a new price tick.
814    ///
815    /// Returns `None` until the indicator has completed its warm-up period.
816    pub fn value(&self) -> Option<f64> {
817        self.last_rsi
818    }
819
820    /// Check if RSI has enough data
821    pub fn is_ready(&self) -> bool {
822        self.gains.is_ready() && self.losses.is_ready()
823    }
824
825    /// Get the period
826    pub fn period(&self) -> usize {
827        self.period
828    }
829
830    /// Reset the RSI state
831    pub fn reset(&mut self) {
832        self.gains.reset();
833        self.losses.reset();
834        self.prev_close = None;
835        self.last_rsi = None;
836    }
837}
838
839// ============================================================================
840// Helper Functions
841// ============================================================================
842
/// Arithmetic mean of `prices`.
///
/// Returns `0.0` (not `NaN`) for an empty slice.
pub fn calculate_sma(prices: &[f64]) -> f64 {
    match prices.len() {
        0 => 0.0,
        count => prices.iter().sum::<f64>() / count as f64,
    }
}
850
851// ============================================================================
852// Tests
853// ============================================================================
854
855#[cfg(test)]
856mod tests {
857    use super::*;
858
859    // --- EMA Tests ---
860
861    #[test]
862    fn test_ema_creation() {
863        let ema = EMA::new(10);
864        assert_eq!(ema.period(), 10);
865        assert!(!ema.is_ready());
866        assert!(ema.value().is_none());
867    }
868
869    #[test]
870    fn test_ema_warmup() {
871        let mut ema = EMA::new(10);
872
873        // Should return None during warmup
874        for i in 1..10 {
875            let result = ema.update(i as f64 * 10.0);
876            assert!(result.is_none(), "Should be None during warmup at step {i}");
877        }
878
879        // Should return Some after warmup
880        let result = ema.update(100.0);
881        assert!(result.is_some(), "Should be ready after {0} updates", 10);
882        assert!(ema.is_ready());
883    }
884
885    #[test]
886    fn test_ema_calculation() {
887        let mut ema = EMA::new(10);
888
889        // Warm up
890        for i in 1..=10 {
891            ema.update(i as f64 * 10.0);
892        }
893
894        assert!(ema.is_ready());
895        let value = ema.value().unwrap();
896        // EMA should be between the min and max input values
897        assert!(value > 10.0 && value <= 100.0);
898    }
899
900    #[test]
901    fn test_ema_tracks_trend() {
902        let mut ema = EMA::new(5);
903
904        // Warm up with constant price
905        for _ in 0..5 {
906            ema.update(100.0);
907        }
908        let stable = ema.value().unwrap();
909
910        // Feed higher prices
911        for _ in 0..10 {
912            ema.update(110.0);
913        }
914        let after_up = ema.value().unwrap();
915
916        assert!(after_up > stable, "EMA should increase with rising prices");
917    }
918
919    #[test]
920    fn test_ema_reset() {
921        let mut ema = EMA::new(5);
922        for _ in 0..10 {
923            ema.update(100.0);
924        }
925        assert!(ema.is_ready());
926
927        ema.reset();
928        assert!(!ema.is_ready());
929        assert!(ema.value().is_none());
930    }
931
932    // --- ATR Tests ---
933
934    #[test]
935    fn test_atr_creation() {
936        let atr = ATR::new(14);
937        assert_eq!(atr.period(), 14);
938        assert!(!atr.is_ready());
939    }
940
941    #[test]
942    fn test_atr_warmup() {
943        let mut atr = ATR::new(14);
944
945        for i in 1..=14 {
946            let base = 100.0 + i as f64;
947            let result = atr.update(base + 1.0, base - 1.0, base);
948            if i < 14 {
949                assert!(result.is_none());
950            }
951        }
952
953        assert!(atr.is_ready());
954    }
955
956    #[test]
957    fn test_atr_increases_with_volatility() {
958        let mut atr = ATR::new(14);
959
960        // Low volatility warmup
961        for i in 1..=14 {
962            let base = 100.0 + i as f64 * 0.1;
963            atr.update(base + 0.5, base - 0.5, base);
964        }
965        let low_vol_atr = atr.value().unwrap();
966
967        // High volatility bars
968        for i in 0..20 {
969            let base = 100.0 + if i % 2 == 0 { 5.0 } else { -5.0 };
970            atr.update(base + 3.0, base - 3.0, base);
971        }
972        let high_vol_atr = atr.value().unwrap();
973
974        assert!(
975            high_vol_atr > low_vol_atr,
976            "ATR should increase with volatility: {high_vol_atr} vs {low_vol_atr}"
977        );
978    }
979
980    #[test]
981    fn test_atr_reset() {
982        let mut atr = ATR::new(14);
983        for i in 0..20 {
984            let base = 100.0 + i as f64;
985            atr.update(base + 1.0, base - 1.0, base);
986        }
987        assert!(atr.is_ready());
988
989        atr.reset();
990        assert!(!atr.is_ready());
991        assert!(atr.value().is_none());
992    }
993
994    // --- ADX Tests ---
995
996    #[test]
997    fn test_adx_creation() {
998        let adx = ADX::new(14);
999        assert_eq!(adx.period(), 14);
1000        assert!(!adx.is_ready());
1001    }
1002
1003    #[test]
1004    fn test_adx_trending_detection() {
1005        let mut adx = ADX::new(14);
1006
1007        // Simulate strong uptrend (prices going up steadily)
1008        for i in 1..=50 {
1009            let high = 100.0 + i as f64 * 2.0;
1010            let low = 100.0 + i as f64 * 2.0 - 1.0;
1011            let close = 100.0 + i as f64 * 2.0 - 0.5;
1012            adx.update(high, low, close);
1013        }
1014
1015        if let Some(adx_value) = adx.value() {
1016            assert!(
1017                adx_value > 20.0,
1018                "ADX should indicate trend in strong uptrend: {adx_value}"
1019            );
1020        }
1021    }
1022
1023    #[test]
1024    fn test_adx_trend_direction() {
1025        let mut adx = ADX::new(14);
1026
1027        // Strong uptrend
1028        for i in 1..=50 {
1029            let high = 100.0 + i as f64 * 2.0;
1030            let low = 100.0 + i as f64 * 2.0 - 1.0;
1031            let close = 100.0 + i as f64 * 2.0 - 0.5;
1032            adx.update(high, low, close);
1033        }
1034
1035        if let Some(dir) = adx.trend_direction() {
1036            assert_eq!(
1037                dir,
1038                TrendDirection::Bullish,
1039                "Should detect bullish direction in uptrend"
1040            );
1041        }
1042    }
1043
1044    #[test]
1045    fn test_adx_di_values() {
1046        let mut adx = ADX::new(14);
1047
1048        for i in 1..=50 {
1049            let high = 100.0 + i as f64 * 2.0;
1050            let low = 100.0 + i as f64 * 2.0 - 1.0;
1051            let close = 100.0 + i as f64 * 2.0 - 0.5;
1052            adx.update(high, low, close);
1053        }
1054
1055        // In an uptrend, +DI should be higher than -DI
1056        if let (Some(plus), Some(minus)) = (adx.plus_dir_index(), adx.minus_dir_index()) {
1057            assert!(
1058                plus > minus,
1059                "+DI ({plus}) should be > -DI ({minus}) in uptrend"
1060            );
1061        }
1062    }
1063
1064    #[test]
1065    fn test_adx_reset() {
1066        let mut adx = ADX::new(14);
1067        for i in 1..=50 {
1068            let base = 100.0 + i as f64;
1069            adx.update(base + 1.0, base - 1.0, base);
1070        }
1071        assert!(adx.is_ready());
1072
1073        adx.reset();
1074        assert!(!adx.is_ready());
1075        assert!(adx.value().is_none());
1076        assert!(adx.plus_dir_index().is_none());
1077        assert!(adx.minus_dir_index().is_none());
1078    }
1079
1080    // --- Bollinger Bands Tests ---
1081
1082    #[test]
1083    fn test_bb_creation() {
1084        let bb = BollingerBands::new(20, 2.0);
1085        assert_eq!(bb.period(), 20);
1086        assert_eq!(bb.std_dev_multiplier(), 2.0);
1087        assert!(!bb.is_ready());
1088    }
1089
1090    #[test]
1091    fn test_bb_warmup() {
1092        let mut bb = BollingerBands::new(20, 2.0);
1093
1094        for i in 1..20 {
1095            let result = bb.update(100.0 + i as f64 * 0.1);
1096            assert!(result.is_none());
1097        }
1098
1099        let result = bb.update(102.0);
1100        assert!(result.is_some());
1101        assert!(bb.is_ready());
1102    }
1103
1104    #[test]
1105    fn test_bb_band_ordering() {
1106        let mut bb = BollingerBands::new(20, 2.0);
1107
1108        for i in 1..=25 {
1109            let price = 100.0 + (i as f64 % 5.0);
1110            bb.update(price);
1111        }
1112
1113        let result = bb.update(102.0).unwrap();
1114        assert!(
1115            result.upper > result.middle,
1116            "Upper band ({}) should be > middle ({})",
1117            result.upper,
1118            result.middle
1119        );
1120        assert!(
1121            result.middle > result.lower,
1122            "Middle ({}) should be > lower ({})",
1123            result.middle,
1124            result.lower
1125        );
1126    }
1127
1128    #[test]
1129    fn test_bb_percent_b() {
1130        let mut bb = BollingerBands::new(20, 2.0);
1131
1132        // Build some history with variance
1133        for i in 1..=20 {
1134            bb.update(100.0 + (i as f64 % 3.0));
1135        }
1136
1137        // Price at middle should give %B near 0.5
1138        let values = bb.update(100.0 + 1.0);
1139        if let Some(v) = values {
1140            // %B should be between 0 and 1 for normal prices
1141            assert!(
1142                v.percent_b >= 0.0 && v.percent_b <= 1.0,
1143                "%B should be in [0,1]: {}",
1144                v.percent_b
1145            );
1146        }
1147    }
1148
1149    #[test]
1150    fn test_bb_squeeze_detection() {
1151        let mut bb = BollingerBands::new(20, 2.0);
1152
1153        // First, create wide bands with volatile data
1154        for i in 0..50 {
1155            let price = 100.0 + if i % 2 == 0 { 10.0 } else { -10.0 };
1156            bb.update(price);
1157        }
1158
1159        // Then tighten with constant price
1160        for _ in 0..50 {
1161            bb.update(100.0);
1162        }
1163
1164        let result = bb.update(100.0).unwrap();
1165        // After constant prices, width percentile should be low
1166        assert!(
1167            result.width_percentile < 50.0,
1168            "Constant prices should produce low width percentile: {}",
1169            result.width_percentile
1170        );
1171    }
1172
1173    #[test]
1174    fn test_bb_overbought_oversold() {
1175        let mut bb = BollingerBands::new(20, 2.0);
1176
1177        // Build history around 100
1178        for _ in 0..20 {
1179            bb.update(100.0);
1180        }
1181
1182        // Price far above should be overbought
1183        let result = bb.update(110.0).unwrap();
1184        assert!(
1185            result.is_overbought(),
1186            "Price far above bands should be overbought, %B = {}",
1187            result.percent_b
1188        );
1189    }
1190
1191    #[test]
1192    fn test_bb_reset() {
1193        let mut bb = BollingerBands::new(20, 2.0);
1194        for i in 0..25 {
1195            bb.update(100.0 + i as f64);
1196        }
1197        assert!(bb.is_ready());
1198
1199        bb.reset();
1200        assert!(!bb.is_ready());
1201    }
1202
1203    // --- RSI Tests ---
1204
1205    #[test]
1206    fn test_rsi_creation() {
1207        let rsi = RSI::new(14);
1208        assert_eq!(rsi.period(), 14);
1209        assert!(!rsi.is_ready());
1210    }
1211
1212    #[test]
1213    fn test_rsi_bullish_market() {
1214        let mut rsi = RSI::new(14);
1215
1216        // Consistently rising prices
1217        let mut last_rsi = None;
1218        for i in 0..30 {
1219            let price = 100.0 + i as f64;
1220            if let Some(val) = rsi.update(price) {
1221                last_rsi = Some(val);
1222            }
1223        }
1224
1225        if let Some(val) = last_rsi {
1226            assert!(
1227                val > 50.0,
1228                "RSI should be above 50 in bullish market: {val}"
1229            );
1230        }
1231    }
1232
1233    #[test]
1234    fn test_rsi_bearish_market() {
1235        let mut rsi = RSI::new(14);
1236
1237        // Consistently falling prices
1238        let mut last_rsi = None;
1239        for i in 0..30 {
1240            let price = 200.0 - i as f64;
1241            if let Some(val) = rsi.update(price) {
1242                last_rsi = Some(val);
1243            }
1244        }
1245
1246        if let Some(val) = last_rsi {
1247            assert!(
1248                val < 50.0,
1249                "RSI should be below 50 in bearish market: {val}"
1250            );
1251        }
1252    }
1253
1254    #[test]
1255    fn test_rsi_range() {
1256        let mut rsi = RSI::new(14);
1257
1258        for i in 0..50 {
1259            let price = 100.0 + (i as f64 * 0.7).sin() * 10.0;
1260            if let Some(val) = rsi.update(price) {
1261                assert!(
1262                    (0.0..=100.0).contains(&val),
1263                    "RSI should be in [0, 100]: {val}"
1264                );
1265            }
1266        }
1267    }
1268
1269    #[test]
1270    fn test_rsi_value_cached() {
1271        let mut rsi = RSI::new(14);
1272        assert!(
1273            rsi.value().is_none(),
1274            "value() should be None before warmup"
1275        );
1276
1277        let mut last_from_update = None;
1278        for i in 0..30 {
1279            let price = 100.0 + i as f64;
1280            if let Some(v) = rsi.update(price) {
1281                last_from_update = Some(v);
1282            }
1283        }
1284
1285        // value() must equal the last result returned by update()
1286        assert_eq!(
1287            rsi.value(),
1288            last_from_update,
1289            "value() must equal the last update() result"
1290        );
1291    }
1292
1293    #[test]
1294    fn test_rsi_reset_clears_value() {
1295        let mut rsi = RSI::new(14);
1296        for i in 0..30 {
1297            rsi.update(100.0 + i as f64);
1298        }
1299        assert!(rsi.value().is_some());
1300        rsi.reset();
1301        assert!(rsi.value().is_none(), "value() should be None after reset");
1302    }
1303
1304    // --- SMA Helper Test ---
1305
1306    #[test]
1307    fn test_calculate_sma() {
1308        assert_eq!(calculate_sma(&[1.0, 2.0, 3.0, 4.0, 5.0]), 3.0);
1309        assert_eq!(calculate_sma(&[100.0]), 100.0);
1310        assert_eq!(calculate_sma(&[]), 0.0);
1311    }
1312
1313    #[test]
1314    fn test_calculate_sma_precision() {
1315        let prices = vec![10.0, 20.0, 30.0];
1316        let sma = calculate_sma(&prices);
1317        assert!((sma - 20.0).abs() < f64::EPSILON);
1318    }
1319}