indexes_rs/v2/adx/
main.rs

1use crate::v2::adx::types::{ADXConfig, ADXError, ADXInput, ADXOutput, ADXPeriodData, ADXState, TrendDirection, TrendStrength};
2
3/// Average Directional Index (ADX) Indicator
4///
5/// ADX measures trend strength regardless of direction. It's composed of:
6/// - True Range (TR): Max of (H-L), (H-Cp), (Cp-L)
7/// - Directional Movement: +DM = H-Hp (if positive), -DM = Lp-L (if positive)
8/// - Directional Indicators: +DI = (+DM smoothed / TR smoothed) * 100
9/// - Directional Index: DX = (|+DI - -DI| / (+DI + -DI)) * 100
10/// - ADX = Smoothed average of DX values
11///
12/// ADX interpretation:
13/// - 0-25: Weak trend or ranging market
14/// - 25-50: Strong trend
15/// - 50+: Very strong trend
16pub struct ADX {
17    state: ADXState,
18}
19
20impl ADX {
21    /// Create a new ADX calculator with default configuration (period=14)
22    pub fn new() -> Self {
23        Self::with_config(ADXConfig::default())
24    }
25
26    /// Create a new ADX calculator with custom period
27    pub fn with_period(period: usize) -> Result<Self, ADXError> {
28        if period == 0 {
29            return Err(ADXError::InvalidPeriod);
30        }
31
32        let config = ADXConfig {
33            period,
34            adx_smoothing: period,
35            ..Default::default()
36        };
37        Ok(Self::with_config(config))
38    }
39
40    /// Create a new ADX calculator with custom period and ADX smoothing
41    pub fn with_periods(period: usize, adx_smoothing: usize) -> Result<Self, ADXError> {
42        if period == 0 || adx_smoothing == 0 {
43            return Err(ADXError::InvalidPeriod);
44        }
45
46        let config = ADXConfig {
47            period,
48            adx_smoothing,
49            ..Default::default()
50        };
51        Ok(Self::with_config(config))
52    }
53
54    /// Create a new ADX calculator with custom configuration
55    pub fn with_config(config: ADXConfig) -> Self {
56        Self { state: ADXState::new(config) }
57    }
58
59    /// Calculate ADX for the given input
60    pub fn calculate(&mut self, input: ADXInput) -> Result<ADXOutput, ADXError> {
61        // Validate input
62        self.validate_input(&input)?;
63        self.validate_config()?;
64
65        if self.state.is_first {
66            self.handle_first_calculation(input)
67        } else {
68            self.handle_normal_calculation(input)
69        }
70    }
71
72    /// Calculate ADX for a batch of inputs
73    pub fn calculate_batch(&mut self, inputs: &[ADXInput]) -> Result<Vec<ADXOutput>, ADXError> {
74        inputs.iter().map(|input| self.calculate(*input)).collect()
75    }
76
77    /// Reset the calculator state
78    pub fn reset(&mut self) {
79        self.state = ADXState::new(self.state.config);
80    }
81
82    /// Get current state (for serialization/debugging)
83    pub fn get_state(&self) -> &ADXState {
84        &self.state
85    }
86
87    /// Restore state (for deserialization)
88    pub fn set_state(&mut self, state: ADXState) {
89        self.state = state;
90    }
91
92    /// Get current trend strength
93    pub fn trend_strength(&self) -> TrendStrength {
94        if let Some(adx) = self.state.current_adx {
95            self.classify_trend_strength(adx)
96        } else {
97            TrendStrength::Insufficient
98        }
99    }
100
101    /// Get current trend direction
102    pub fn trend_direction(&self) -> Option<TrendDirection> {
103        self.state
104            .period_data
105            .back()
106            .map(|last_data| self.determine_trend_direction(last_data.plus_di, last_data.minus_di))
107    }
108
109    // Private helper methods
110
111    fn validate_input(&self, input: &ADXInput) -> Result<(), ADXError> {
112        // Check for valid prices
113        if !input.high.is_finite() || !input.low.is_finite() || !input.close.is_finite() {
114            return Err(ADXError::InvalidPrice);
115        }
116
117        // Check HLC relationship
118        if input.high < input.low {
119            return Err(ADXError::InvalidHLC);
120        }
121
122        if input.close < input.low || input.close > input.high {
123            return Err(ADXError::InvalidHLC);
124        }
125
126        Ok(())
127    }
128
129    fn validate_config(&self) -> Result<(), ADXError> {
130        if self.state.config.period == 0 || self.state.config.adx_smoothing == 0 {
131            return Err(ADXError::InvalidPeriod);
132        }
133
134        if self.state.config.strong_trend_threshold >= self.state.config.very_strong_trend_threshold {
135            return Err(ADXError::InvalidThresholds);
136        }
137
138        Ok(())
139    }
140
141    fn handle_first_calculation(&mut self, input: ADXInput) -> Result<ADXOutput, ADXError> {
142        // First calculation - just store data
143        self.state.previous_high = Some(input.high);
144        self.state.previous_low = Some(input.low);
145        self.state.previous_close = Some(input.close);
146        self.state.is_first = false;
147
148        // Return default values for first calculation
149        Ok(ADXOutput {
150            adx: 0.0,
151            plus_di: 0.0,
152            minus_di: 0.0,
153            dx: 0.0,
154            true_range: 0.0,
155            trend_strength: TrendStrength::Insufficient,
156            trend_direction: TrendDirection::Sideways,
157            di_spread: 0.0,
158        })
159    }
160
161    fn handle_normal_calculation(&mut self, input: ADXInput) -> Result<ADXOutput, ADXError> {
162        // Calculate True Range
163        let true_range = self.calculate_true_range(&input);
164
165        // Calculate Directional Movements
166        let (plus_dm, minus_dm) = self.calculate_directional_movements(&input);
167
168        // Update or initialize smoothed values
169        if self.state.period_data.len() < self.state.config.period {
170            // Not enough data for smoothing yet - accumulate
171            self.accumulate_initial_data(true_range, plus_dm, minus_dm);
172        } else {
173            // Use smoothing formula
174            self.update_smoothed_values(true_range, plus_dm, minus_dm);
175        }
176
177        // Calculate DI values
178        let (plus_di, minus_di) = self.calculate_directional_indicators();
179
180        // Calculate DX
181        let dx = self.calculate_dx(plus_di, minus_di)?;
182
183        // Calculate ADX
184        let adx = self.calculate_adx(dx);
185
186        // Create period data
187        let period_data = ADXPeriodData {
188            true_range,
189            plus_dm,
190            minus_dm,
191            plus_di,
192            minus_di,
193            dx,
194        };
195
196        // Store period data
197        if self.state.period_data.len() >= self.state.config.period {
198            self.state.period_data.pop_front();
199        }
200        self.state.period_data.push_back(period_data);
201
202        // Update state
203        self.state.previous_high = Some(input.high);
204        self.state.previous_low = Some(input.low);
205        self.state.previous_close = Some(input.close);
206
207        // Determine outputs
208        let trend_strength = self.classify_trend_strength(adx);
209        let trend_direction = self.determine_trend_direction(plus_di, minus_di);
210        let di_spread = plus_di - minus_di;
211
212        Ok(ADXOutput {
213            adx,
214            plus_di,
215            minus_di,
216            dx,
217            true_range,
218            trend_strength,
219            trend_direction,
220            di_spread,
221        })
222    }
223
224    fn calculate_true_range(&self, input: &ADXInput) -> f64 {
225        if let Some(prev_close) = self.state.previous_close {
226            let hl = input.high - input.low;
227            let hc = (input.high - prev_close).abs();
228            let lc = (input.low - prev_close).abs();
229            hl.max(hc).max(lc)
230        } else {
231            input.high - input.low
232        }
233    }
234
235    fn calculate_directional_movements(&self, input: &ADXInput) -> (f64, f64) {
236        if let (Some(prev_high), Some(prev_low)) = (self.state.previous_high, self.state.previous_low) {
237            let up_move = input.high - prev_high;
238            let down_move = prev_low - input.low;
239
240            let plus_dm = if up_move > down_move && up_move > 0.0 { up_move } else { 0.0 };
241
242            let minus_dm = if down_move > up_move && down_move > 0.0 { down_move } else { 0.0 };
243
244            (plus_dm, minus_dm)
245        } else {
246            (0.0, 0.0)
247        }
248    }
249
250    fn accumulate_initial_data(&mut self, true_range: f64, plus_dm: f64, minus_dm: f64) {
251        // For the first period values, we sum them up
252        match self.state.smoothed_tr {
253            Some(tr) => self.state.smoothed_tr = Some(tr + true_range),
254            None => self.state.smoothed_tr = Some(true_range),
255        }
256
257        match self.state.smoothed_plus_dm {
258            Some(dm) => self.state.smoothed_plus_dm = Some(dm + plus_dm),
259            None => self.state.smoothed_plus_dm = Some(plus_dm),
260        }
261
262        match self.state.smoothed_minus_dm {
263            Some(dm) => self.state.smoothed_minus_dm = Some(dm + minus_dm),
264            None => self.state.smoothed_minus_dm = Some(minus_dm),
265        }
266
267        // Check if we have enough data for DI calculation
268        if self.state.period_data.len() + 1 >= self.state.config.period {
269            self.state.has_di_data = true;
270        }
271    }
272
273    fn update_smoothed_values(&mut self, true_range: f64, plus_dm: f64, minus_dm: f64) {
274        let period = self.state.config.period as f64;
275
276        // Wilder's smoothing: New = (Old * (n-1) + Current) / n
277        if let Some(smoothed_tr) = self.state.smoothed_tr {
278            self.state.smoothed_tr = Some((smoothed_tr * (period - 1.0) + true_range) / period);
279        }
280
281        if let Some(smoothed_plus_dm) = self.state.smoothed_plus_dm {
282            self.state.smoothed_plus_dm = Some((smoothed_plus_dm * (period - 1.0) + plus_dm) / period);
283        }
284
285        if let Some(smoothed_minus_dm) = self.state.smoothed_minus_dm {
286            self.state.smoothed_minus_dm = Some((smoothed_minus_dm * (period - 1.0) + minus_dm) / period);
287        }
288    }
289
290    fn calculate_directional_indicators(&self) -> (f64, f64) {
291        if let (Some(smoothed_tr), Some(smoothed_plus_dm), Some(smoothed_minus_dm)) = (self.state.smoothed_tr, self.state.smoothed_plus_dm, self.state.smoothed_minus_dm) {
292            if smoothed_tr != 0.0 {
293                let plus_di = (smoothed_plus_dm / smoothed_tr) * 100.0;
294                let minus_di = (smoothed_minus_dm / smoothed_tr) * 100.0;
295                (plus_di, minus_di)
296            } else {
297                (0.0, 0.0)
298            }
299        } else {
300            (0.0, 0.0)
301        }
302    }
303
304    fn calculate_dx(&self, plus_di: f64, minus_di: f64) -> Result<f64, ADXError> {
305        let di_sum = plus_di + minus_di;
306        if di_sum == 0.0 {
307            Ok(0.0)
308        } else {
309            let di_diff = (plus_di - minus_di).abs();
310            Ok((di_diff / di_sum) * 100.0)
311        }
312    }
313
314    fn calculate_adx(&mut self, dx: f64) -> f64 {
315        // Add DX to history
316        if self.state.dx_history.len() >= self.state.config.adx_smoothing {
317            self.state.dx_history.pop_front();
318        }
319        self.state.dx_history.push_back(dx);
320
321        // Calculate ADX
322        if self.state.dx_history.len() >= self.state.config.adx_smoothing {
323            if !self.state.has_adx_data {
324                // First ADX calculation - simple average
325                let adx = self.state.dx_history.iter().sum::<f64>() / self.state.dx_history.len() as f64;
326                self.state.current_adx = Some(adx);
327                self.state.has_adx_data = true;
328                adx
329            } else {
330                // Subsequent ADX calculations - use smoothing
331                if let Some(prev_adx) = self.state.current_adx {
332                    let period = self.state.config.adx_smoothing as f64;
333                    let adx = (prev_adx * (period - 1.0) + dx) / period;
334                    self.state.current_adx = Some(adx);
335                    adx
336                } else {
337                    0.0
338                }
339            }
340        } else {
341            0.0
342        }
343    }
344
345    fn classify_trend_strength(&self, adx: f64) -> TrendStrength {
346        if !self.state.has_adx_data {
347            TrendStrength::Insufficient
348        } else if adx >= self.state.config.very_strong_trend_threshold {
349            TrendStrength::VeryStrong
350        } else if adx >= self.state.config.strong_trend_threshold {
351            TrendStrength::Strong
352        } else {
353            TrendStrength::Weak
354        }
355    }
356
357    fn determine_trend_direction(&self, plus_di: f64, minus_di: f64) -> TrendDirection {
358        if plus_di > minus_di {
359            TrendDirection::Up
360        } else if minus_di > plus_di {
361            TrendDirection::Down
362        } else {
363            TrendDirection::Sideways
364        }
365    }
366}
367
368impl Default for ADX {
369    fn default() -> Self {
370        Self::new()
371    }
372}
373
374/// Convenience function to calculate ADX for HLC data without maintaining state
375pub fn calculate_adx_simple(highs: &[f64], lows: &[f64], closes: &[f64], period: usize) -> Result<Vec<f64>, ADXError> {
376    let len = highs.len();
377    if len != lows.len() || len != closes.len() {
378        return Err(ADXError::InvalidInput("All price arrays must have same length".to_string()));
379    }
380
381    if len == 0 {
382        return Ok(Vec::new());
383    }
384
385    let mut adx_calculator = ADX::with_period(period)?;
386    let mut results = Vec::with_capacity(len);
387
388    for i in 0..len {
389        let input = ADXInput {
390            high: highs[i],
391            low: lows[i],
392            close: closes[i],
393        };
394        let output = adx_calculator.calculate(input)?;
395        results.push(output.adx);
396    }
397
398    Ok(results)
399}