mantis_ta/indicators/trend/
adx.rs1use crate::indicators::Indicator;
2use crate::types::{AdxOutput, Candle};
3use crate::utils::ringbuf::RingBuf;
4
/// Average Directional Index (ADX) together with the +DI and -DI lines.
///
/// Streaming indicator: it keeps running true-range / directional-movement totals
/// (plain sums during warm-up, Wilder-smoothed afterwards) plus a short history of
/// DX values, and is advanced one bar at a time.
#[derive(Debug, Clone)]
pub struct ADX {
    /// Smoothing length; must be > 0 (asserted in `new`).
    period: usize,
    /// High of the previous bar; `None` until the first bar has been seen.
    prev_high: Option<f64>,
    /// Low of the previous bar; `None` until the first bar has been seen.
    prev_low: Option<f64>,
    /// Close of the previous bar, used for the true-range gap terms.
    prev_close: Option<f64>,
    /// True-range total: a plain sum during warm-up, Wilder-smoothed afterwards.
    tr_sum: f64,
    /// +DM total, accumulated and smoothed in lockstep with `tr_sum`.
    plus_dm_sum: f64,
    /// -DM total, accumulated and smoothed in lockstep with `tr_sum`.
    minus_dm_sum: f64,
    /// Number of bars fed since construction or the last reset.
    bar_count: usize,
    /// Latest +DI, scaled to percent; `None` until the smoothed sums exist.
    plus_di: Option<f64>,
    /// Latest -DI, scaled to percent; `None` until the smoothed sums exist.
    minus_di: Option<f64>,
    /// Recent DX values. Despite the name, it stores DX, not DI. Its contents are
    /// averaged to seed the first ADX value.
    /// NOTE(review): built with `RingBuf::new(period, 0.0)`, presumably capacity `period`.
    di_history: RingBuf<f64>,
    /// Current ADX; `None` until it has been seeded.
    adx: Option<f64>,
}
49
50impl ADX {
51 pub fn new(period: usize) -> Self {
52 assert!(period > 0, "period must be > 0");
53 Self {
54 period,
55 prev_high: None,
56 prev_low: None,
57 prev_close: None,
58 tr_sum: 0.0,
59 plus_dm_sum: 0.0,
60 minus_dm_sum: 0.0,
61 bar_count: 0,
62 plus_di: None,
63 minus_di: None,
64 di_history: RingBuf::new(period, 0.0),
65 adx: None,
66 }
67 }
68
69 #[inline]
70 fn true_range(high: f64, low: f64, prev_close: f64) -> f64 {
71 let hl = high - low;
72 let hc = (high - prev_close).abs();
73 let lc = (low - prev_close).abs();
74 hl.max(hc).max(lc)
75 }
76
77 #[inline]
78 fn update(&mut self, high: f64, low: f64, close: f64) -> Option<AdxOutput> {
79 self.bar_count += 1;
80
81 if let (Some(ph), Some(pl), Some(pc)) = (self.prev_high, self.prev_low, self.prev_close) {
82 let tr = Self::true_range(high, low, pc);
83 let up_move = high - ph;
84 let down_move = pl - low;
85
86 let plus_dm = if up_move > down_move && up_move > 0.0 {
87 up_move
88 } else {
89 0.0
90 };
91
92 let minus_dm = if down_move > up_move && down_move > 0.0 {
93 down_move
94 } else {
95 0.0
96 };
97
98 if self.bar_count <= self.period {
99 self.tr_sum += tr;
100 self.plus_dm_sum += plus_dm;
101 self.minus_dm_sum += minus_dm;
102 } else if self.bar_count == self.period + 1 {
103 self.tr_sum = self.tr_sum - self.tr_sum / self.period as f64 + tr;
104 self.plus_dm_sum =
105 self.plus_dm_sum - self.plus_dm_sum / self.period as f64 + plus_dm;
106 self.minus_dm_sum =
107 self.minus_dm_sum - self.minus_dm_sum / self.period as f64 + minus_dm;
108 self.plus_di = Some((self.plus_dm_sum / self.tr_sum) * 100.0);
109 self.minus_di = Some((self.minus_dm_sum / self.tr_sum) * 100.0);
110 } else {
111 self.tr_sum = self.tr_sum - self.tr_sum / self.period as f64 + tr;
112 self.plus_dm_sum =
113 self.plus_dm_sum - self.plus_dm_sum / self.period as f64 + plus_dm;
114 self.minus_dm_sum =
115 self.minus_dm_sum - self.minus_dm_sum / self.period as f64 + minus_dm;
116
117 let new_plus_di = (self.plus_dm_sum / self.tr_sum) * 100.0;
118 let new_minus_di = (self.minus_dm_sum / self.tr_sum) * 100.0;
119 self.plus_di = Some(new_plus_di);
120 self.minus_di = Some(new_minus_di);
121
122 let di_diff = (new_plus_di - new_minus_di).abs();
123 let di_sum = new_plus_di + new_minus_di;
124 let dx = if di_sum > 0.0 {
125 (di_diff / di_sum) * 100.0
126 } else {
127 0.0
128 };
129
130 self.di_history.push(dx);
131
132 match self.bar_count.cmp(&(self.period * 2)) {
133 std::cmp::Ordering::Equal => {
134 let adx_sum: f64 = self.di_history.iter().sum();
135 self.adx = Some(adx_sum / self.period as f64);
136 }
137 std::cmp::Ordering::Greater => {
138 let prev_adx = self.adx.unwrap_or(0.0);
139 self.adx =
140 Some((prev_adx * (self.period - 1) as f64 + dx) / self.period as f64);
141 }
142 std::cmp::Ordering::Less => {}
143 }
144 }
145 }
146
147 self.prev_high = Some(high);
148 self.prev_low = Some(low);
149 self.prev_close = Some(close);
150
151 if self.bar_count > self.period * 2 {
152 Some(AdxOutput {
153 plus_di: self.plus_di.unwrap_or(0.0),
154 minus_di: self.minus_di.unwrap_or(0.0),
155 adx: self.adx.unwrap_or(0.0),
156 })
157 } else {
158 None
159 }
160 }
161}
162
163impl Indicator for ADX {
164 type Output = AdxOutput;
165
166 fn next(&mut self, candle: &Candle) -> Option<Self::Output> {
167 self.update(candle.high, candle.low, candle.close)
168 }
169
170 fn reset(&mut self) {
171 self.prev_high = None;
172 self.prev_low = None;
173 self.prev_close = None;
174 self.tr_sum = 0.0;
175 self.plus_dm_sum = 0.0;
176 self.minus_dm_sum = 0.0;
177 self.bar_count = 0;
178 self.plus_di = None;
179 self.minus_di = None;
180 self.di_history = RingBuf::new(self.period, 0.0);
181 self.adx = None;
182 }
183
184 fn warmup_period(&self) -> usize {
185 self.period * 2
186 }
187
188 fn clone_boxed(&self) -> Box<dyn Indicator<Output = Self::Output>> {
189 Box::new(self.clone())
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` candles on a strict uptrend: each close is one unit above the last,
    /// with a fixed ±1 range around it.
    fn rising_candles(n: usize) -> Vec<Candle> {
        (0..n)
            .map(|i| {
                let price = 100.0 + i as f64;
                Candle {
                    timestamp: i as i64,
                    open: price,
                    high: price + 1.0,
                    low: price - 1.0,
                    close: price,
                    volume: 0.0,
                }
            })
            .collect()
    }

    #[test]
    fn adx_emits_after_warmup() {
        let candles = rising_candles(10);
        let out = ADX::new(3).calculate(&candles);
        assert!(out.iter().take(5).all(|v| v.is_none()));
        assert!(out.iter().skip(5).any(|v| v.is_some()));
    }

    #[test]
    fn warmup_is_twice_the_period() {
        assert_eq!(ADX::new(3).warmup_period(), 6);
    }

    #[test]
    fn uptrend_readings_are_bounded_and_bullish() {
        let candles = rising_candles(20);
        let out = ADX::new(3).calculate(&candles);
        let readings: Vec<&AdxOutput> = out.iter().flatten().collect();
        assert!(!readings.is_empty());
        for r in readings {
            for v in [r.plus_di, r.minus_di, r.adx] {
                assert!(
                    v.is_finite() && (0.0..=100.0).contains(&v),
                    "value out of range: {}",
                    v
                );
            }
            // No bar ever moves down, so +DI must dominate.
            assert!(r.plus_di > r.minus_di);
        }
    }

    #[test]
    fn reset_replays_identically() {
        let candles = rising_candles(12);
        let mut adx = ADX::new(3);
        let run = |ind: &mut ADX| -> Vec<Option<(f64, f64, f64)>> {
            candles
                .iter()
                .map(|c| ind.next(c).map(|o| (o.plus_di, o.minus_di, o.adx)))
                .collect()
        };
        let first = run(&mut adx);
        adx.reset();
        let second = run(&mut adx);
        assert_eq!(first, second);
    }

    #[test]
    #[should_panic(expected = "period must be > 0")]
    fn zero_period_panics() {
        let _ = ADX::new(0);
    }
}