Skip to main content

mantis_ta/indicators/trend/
adx.rs

1use crate::indicators::Indicator;
2use crate::types::{AdxOutput, Candle};
3use crate::utils::ringbuf::RingBuf;
4
5/// Average Directional Index measuring trend strength.
6///
7/// ADX combines +DI and -DI to measure trend strength (0-100).
8/// Output includes +DI, -DI, and ADX values.
9///
10/// # Examples
11/// ```rust
12/// use mantis_ta::indicators::{Indicator, ADX};
13/// use mantis_ta::types::Candle;
14///
15/// let candles: Vec<Candle> = (0..50)
16///     .map(|i| {
17///         let price = 100.0 + i as f64;
18///         Candle {
19///             timestamp: i as i64,
20///             open: price,
21///             high: price + 1.0,
22///             low: price - 1.0,
23///             close: price,
24///             volume: 0.0,
25///         }
26///     })
27///     .collect();
28///
29/// let out = ADX::new(14).calculate(&candles);
30/// // Warmup period = 14 * 2 = 28
31/// assert!(out.iter().take(27).all(|v| v.is_none()));
32/// assert!(out.iter().skip(27).any(|v| v.is_some()));
33/// ```
34#[derive(Debug, Clone)]
35pub struct ADX {
36    period: usize,
37    prev_high: Option<f64>,
38    prev_low: Option<f64>,
39    prev_close: Option<f64>,
40    tr_sum: f64,
41    plus_dm_sum: f64,
42    minus_dm_sum: f64,
43    bar_count: usize,
44    plus_di: Option<f64>,
45    minus_di: Option<f64>,
46    di_history: RingBuf<f64>,
47    adx: Option<f64>,
48}
49
50impl ADX {
51    pub fn new(period: usize) -> Self {
52        assert!(period > 0, "period must be > 0");
53        Self {
54            period,
55            prev_high: None,
56            prev_low: None,
57            prev_close: None,
58            tr_sum: 0.0,
59            plus_dm_sum: 0.0,
60            minus_dm_sum: 0.0,
61            bar_count: 0,
62            plus_di: None,
63            minus_di: None,
64            di_history: RingBuf::new(period, 0.0),
65            adx: None,
66        }
67    }
68
69    #[inline]
70    fn true_range(high: f64, low: f64, prev_close: f64) -> f64 {
71        let hl = high - low;
72        let hc = (high - prev_close).abs();
73        let lc = (low - prev_close).abs();
74        hl.max(hc).max(lc)
75    }
76
77    #[inline]
78    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<AdxOutput> {
79        self.bar_count += 1;
80
81        if let (Some(ph), Some(pl), Some(pc)) = (self.prev_high, self.prev_low, self.prev_close) {
82            let tr = Self::true_range(high, low, pc);
83            let up_move = high - ph;
84            let down_move = pl - low;
85
86            let plus_dm = if up_move > down_move && up_move > 0.0 {
87                up_move
88            } else {
89                0.0
90            };
91
92            let minus_dm = if down_move > up_move && down_move > 0.0 {
93                down_move
94            } else {
95                0.0
96            };
97
98            if self.bar_count <= self.period {
99                self.tr_sum += tr;
100                self.plus_dm_sum += plus_dm;
101                self.minus_dm_sum += minus_dm;
102            } else if self.bar_count == self.period + 1 {
103                self.tr_sum = self.tr_sum - self.tr_sum / self.period as f64 + tr;
104                self.plus_dm_sum =
105                    self.plus_dm_sum - self.plus_dm_sum / self.period as f64 + plus_dm;
106                self.minus_dm_sum =
107                    self.minus_dm_sum - self.minus_dm_sum / self.period as f64 + minus_dm;
108                self.plus_di = Some((self.plus_dm_sum / self.tr_sum) * 100.0);
109                self.minus_di = Some((self.minus_dm_sum / self.tr_sum) * 100.0);
110            } else {
111                self.tr_sum = self.tr_sum - self.tr_sum / self.period as f64 + tr;
112                self.plus_dm_sum =
113                    self.plus_dm_sum - self.plus_dm_sum / self.period as f64 + plus_dm;
114                self.minus_dm_sum =
115                    self.minus_dm_sum - self.minus_dm_sum / self.period as f64 + minus_dm;
116
117                let new_plus_di = (self.plus_dm_sum / self.tr_sum) * 100.0;
118                let new_minus_di = (self.minus_dm_sum / self.tr_sum) * 100.0;
119                self.plus_di = Some(new_plus_di);
120                self.minus_di = Some(new_minus_di);
121
122                let di_diff = (new_plus_di - new_minus_di).abs();
123                let di_sum = new_plus_di + new_minus_di;
124                let dx = if di_sum > 0.0 {
125                    (di_diff / di_sum) * 100.0
126                } else {
127                    0.0
128                };
129
130                self.di_history.push(dx);
131
132                match self.bar_count.cmp(&(self.period * 2)) {
133                    std::cmp::Ordering::Equal => {
134                        let adx_sum: f64 = self.di_history.iter().sum();
135                        self.adx = Some(adx_sum / self.period as f64);
136                    }
137                    std::cmp::Ordering::Greater => {
138                        let prev_adx = self.adx.unwrap_or(0.0);
139                        self.adx =
140                            Some((prev_adx * (self.period - 1) as f64 + dx) / self.period as f64);
141                    }
142                    std::cmp::Ordering::Less => {}
143                }
144            }
145        }
146
147        self.prev_high = Some(high);
148        self.prev_low = Some(low);
149        self.prev_close = Some(close);
150
151        if self.bar_count > self.period * 2 {
152            Some(AdxOutput {
153                plus_di: self.plus_di.unwrap_or(0.0),
154                minus_di: self.minus_di.unwrap_or(0.0),
155                adx: self.adx.unwrap_or(0.0),
156            })
157        } else {
158            None
159        }
160    }
161}
162
163impl Indicator for ADX {
164    type Output = AdxOutput;
165
166    fn next(&mut self, candle: &Candle) -> Option<Self::Output> {
167        self.update(candle.high, candle.low, candle.close)
168    }
169
170    fn reset(&mut self) {
171        self.prev_high = None;
172        self.prev_low = None;
173        self.prev_close = None;
174        self.tr_sum = 0.0;
175        self.plus_dm_sum = 0.0;
176        self.minus_dm_sum = 0.0;
177        self.bar_count = 0;
178        self.plus_di = None;
179        self.minus_di = None;
180        self.di_history = RingBuf::new(self.period, 0.0);
181        self.adx = None;
182    }
183
184    fn warmup_period(&self) -> usize {
185        self.period * 2
186    }
187
188    fn clone_boxed(&self) -> Box<dyn Indicator<Output = Self::Output>> {
189        Box::new(self.clone())
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` candles on a steady uptrend. Each close is one above the
    /// previous one, with a fixed +/-1 high/low band around it.
    fn rising_candles(count: usize) -> Vec<Candle> {
        (0..count)
            .map(|step| {
                let mid = 100.0 + step as f64;
                Candle {
                    timestamp: step as i64,
                    open: mid,
                    high: mid + 1.0,
                    low: mid - 1.0,
                    close: mid,
                    volume: 0.0,
                }
            })
            .collect()
    }

    #[test]
    fn adx_emits_after_warmup() {
        let out = ADX::new(3).calculate(&rising_candles(10));

        let mut emitted = out.iter().map(Option::is_some);
        // Nothing may be emitted during the first five bars...
        assert!(emitted.by_ref().take(5).all(|flag| !flag));
        // ...and at least one value must appear afterwards.
        assert!(emitted.any(|flag| flag));
    }
}