Skip to main content

quantwave_core/indicators/incremental/
dmi.rs

1//! Native O(1) DMI family — TA-Lib Wilder smoothing parity (ADX, ADXR, DX, +DI, -DI).
2
3use crate::traits::Next;
4
/// Wilder-smoothed true range (TR), +DM and -DM shared by every DMI indicator here.
///
/// Warm-up mirrors TA-Lib's seeding: the first bar only records the previous
/// high/low/close, the deltas of bars `1..timeperiod` are summed plainly, and from bar
/// `timeperiod` onward every sum is updated with `sum = sum - sum / timeperiod + value`.
#[derive(Debug, Clone)]
struct DmiCore {
    /// Smoothing length; 0 disables the core (every step yields `None`).
    timeperiod: usize,
    /// `timeperiod` as `f64`, cached for the smoothing divisions.
    period_f: f64,
    prev_high: Option<f64>,
    prev_low: Option<f64>,
    prev_close: Option<f64>,
    /// Number of bars consumed so far (stays 0 while `timeperiod` is 0).
    bar_index: usize,
    sum_tr: f64,
    sum_pdm: f64,
    sum_mdm: f64,
    /// Set once the plain-sum warm-up is over and Wilder smoothing has begun.
    seeded: bool,
}

impl DmiCore {
    fn new(timeperiod: usize) -> Self {
        Self {
            timeperiod,
            period_f: timeperiod as f64,
            prev_high: None,
            prev_low: None,
            prev_close: None,
            bar_index: 0,
            sum_tr: 0.0,
            sum_pdm: 0.0,
            sum_mdm: 0.0,
            seeded: false,
        }
    }

    /// True range, +DM and -DM of the bar `(high, low)` against the previous bar.
    ///
    /// +DM is the up-move when it is positive and strictly larger than the down-move;
    /// -DM is the mirror case. Whichever does not qualify is 0.
    #[inline]
    fn dm_components(
        (prev_high, prev_low, prev_close): (f64, f64, f64),
        high: f64,
        low: f64,
    ) -> (f64, f64, f64) {
        let true_range = (high - low)
            .max((high - prev_close).abs())
            .max((low - prev_close).abs());

        let up_move = high - prev_high;
        let down_move = prev_low - low;
        let plus_dm = if up_move > 0.0 && up_move > down_move {
            up_move
        } else {
            0.0
        };
        let minus_dm = if down_move > 0.0 && down_move > up_move {
            down_move
        } else {
            0.0
        };
        (true_range, plus_dm, minus_dm)
    }

    /// Feeds one `(high, low, close)` bar.
    ///
    /// Returns `Some((plus_di, minus_di, dx))` once smoothing has started and the
    /// smoothed TR is positive. Returns `None` during warm-up, when `timeperiod` is 0,
    /// or when the smoothed TR is not positive. `dx` is 0 when `+DI + -DI` is not
    /// positive.
    fn step(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64)> {
        if self.timeperiod == 0 {
            return None;
        }

        let previous = (self.prev_high, self.prev_low, self.prev_close);
        self.prev_high = Some(high);
        self.prev_low = Some(low);
        self.prev_close = Some(close);

        let (Some(ph), Some(pl), Some(pc)) = previous else {
            // Very first bar: there is nothing to measure movement against yet.
            self.bar_index = 1;
            return None;
        };

        let (tr, plus_dm, minus_dm) = Self::dm_components((ph, pl, pc), high, low);
        let bar = self.bar_index;
        self.bar_index += 1;

        if !self.seeded && bar < self.timeperiod {
            // Plain sums until `timeperiod - 1` deltas have been collected.
            self.sum_tr += tr;
            self.sum_pdm += plus_dm;
            self.sum_mdm += minus_dm;
            return None;
        }
        self.seeded = true;

        let n = self.period_f;
        let wilder = |sum: f64, value: f64| sum - sum / n + value;
        self.sum_tr = wilder(self.sum_tr, tr);
        self.sum_pdm = wilder(self.sum_pdm, plus_dm);
        self.sum_mdm = wilder(self.sum_mdm, minus_dm);

        if self.sum_tr <= 0.0 {
            return None;
        }
        let plus_di = 100.0 * self.sum_pdm / self.sum_tr;
        let minus_di = 100.0 * self.sum_mdm / self.sum_tr;
        let di_total = plus_di + minus_di;
        let dx = if di_total > 0.0 {
            100.0 * (plus_di - minus_di).abs() / di_total
        } else {
            0.0
        };
        Some((plus_di, minus_di, dx))
    }
}
101
102/// Plus Directional Indicator (+DI).
103#[derive(Debug, Clone)]
104#[allow(non_camel_case_types)]
105pub struct PLUS_DI {
106    pub timeperiod: usize,
107    core: DmiCore,
108}
109
110impl PLUS_DI {
111    pub fn new(timeperiod: usize) -> Self {
112        Self {
113            timeperiod,
114            core: DmiCore::new(timeperiod),
115        }
116    }
117}
118
119impl Next<(f64, f64, f64)> for PLUS_DI {
120    type Output = f64;
121
122    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
123        match self.core.step(high, low, close) {
124            Some((pdi, _, _)) => pdi,
125            None => f64::NAN,
126        }
127    }
128}
129
130/// Minus Directional Indicator (-DI).
131#[derive(Debug, Clone)]
132#[allow(non_camel_case_types)]
133pub struct MINUS_DI {
134    pub timeperiod: usize,
135    core: DmiCore,
136}
137
138impl MINUS_DI {
139    pub fn new(timeperiod: usize) -> Self {
140        Self {
141            timeperiod,
142            core: DmiCore::new(timeperiod),
143        }
144    }
145}
146
147impl Next<(f64, f64, f64)> for MINUS_DI {
148    type Output = f64;
149
150    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
151        match self.core.step(high, low, close) {
152            Some((_, mdi, _)) => mdi,
153            None => f64::NAN,
154        }
155    }
156}
157
158/// Directional Movement Index (DX).
159#[derive(Debug, Clone)]
160#[allow(non_camel_case_types)]
161pub struct DX {
162    pub timeperiod: usize,
163    core: DmiCore,
164}
165
166impl DX {
167    pub fn new(timeperiod: usize) -> Self {
168        Self {
169            timeperiod,
170            core: DmiCore::new(timeperiod),
171        }
172    }
173}
174
175impl Next<(f64, f64, f64)> for DX {
176    type Output = f64;
177
178    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
179        match self.core.step(high, low, close) {
180            Some((_, _, dx)) => dx,
181            None => f64::NAN,
182        }
183    }
184}
185
186/// Average Directional Index (ADX).
187#[derive(Debug, Clone)]
188#[allow(non_camel_case_types)]
189pub struct ADX {
190    pub timeperiod: usize,
191    core: DmiCore,
192    dx_values: Vec<f64>,
193    adx: f64,
194    adx_ready: bool,
195}
196
197impl ADX {
198    pub fn new(timeperiod: usize) -> Self {
199        Self {
200            timeperiod,
201            core: DmiCore::new(timeperiod),
202            dx_values: Vec::new(),
203            adx: 0.0,
204            adx_ready: false,
205        }
206    }
207}
208
209impl Next<(f64, f64, f64)> for ADX {
210    type Output = f64;
211
212    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
213        let period = self.timeperiod;
214        if period < 2 {
215            return f64::NAN;
216        }
217
218        let Some((_, _, dx)) = self.core.step(high, low, close) else {
219            return f64::NAN;
220        };
221
222        let adx_start = 2 * period - 1;
223        let bar = self.core.bar_index.saturating_sub(1);
224
225        if bar < period {
226            return f64::NAN;
227        }
228
229        if bar < adx_start {
230            self.dx_values.push(dx);
231            return f64::NAN;
232        }
233
234        if bar == adx_start {
235            self.dx_values.push(dx);
236            let seed: f64 = self.dx_values.iter().sum::<f64>() / period as f64;
237            self.adx = seed;
238            self.adx_ready = true;
239            return seed;
240        }
241
242        if self.adx_ready {
243            self.adx = (self.adx * (period as f64 - 1.0) + dx) / period as f64;
244            return self.adx;
245        }
246
247        f64::NAN
248    }
249}
250
251/// Average Directional Movement Index Rating (ADXR).
252#[derive(Debug, Clone)]
253#[allow(non_camel_case_types)]
254pub struct ADXR {
255    pub timeperiod: usize,
256    adx: ADX,
257    adx_history: Vec<f64>,
258}
259
260impl ADXR {
261    pub fn new(timeperiod: usize) -> Self {
262        Self {
263            timeperiod,
264            adx: ADX::new(timeperiod),
265            adx_history: Vec::new(),
266        }
267    }
268}
269
270impl Next<(f64, f64, f64)> for ADXR {
271    type Output = f64;
272
273    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
274        let period = self.timeperiod;
275        let adx_val = self.adx.next((high, low, close));
276        self.adx_history.push(adx_val);
277
278        let adxr_lookback = 3 * period - 2;
279        let bar = self.adx_history.len().saturating_sub(1);
280        if bar < adxr_lookback {
281            return f64::NAN;
282        }
283        if adx_val.is_nan() {
284            return f64::NAN;
285        }
286        let past_idx = bar + 1 - period;
287        let past = self.adx_history[past_idx];
288        if past.is_nan() {
289            return f64::NAN;
290        }
291        (adx_val + past) / 2.0
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds well-formed bars from three random series of length at least `len`:
    /// each bar's high is the largest of the three values and its low the smallest,
    /// so `low <= close <= high` always holds.
    fn hlc(len: usize, h: &[f64], l: &[f64], c: &[f64]) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let mut high = Vec::with_capacity(len);
        let mut low = Vec::with_capacity(len);
        let mut close = Vec::with_capacity(len);
        for i in 0..len {
            let (val_h, val_l, val_c) = (h[i], l[i], c[i]);
            high.push(val_h.max(val_l).max(val_c));
            low.push(val_h.min(val_l).min(val_c));
            close.push(val_c);
        }
        (high, low, close)
    }

    /// Asserts that a streaming series matches the batch series bar for bar: same
    /// length, `NaN` at exactly the same positions, and values equal within `1e-6`.
    fn assert_parity(streaming: &[f64], batch: &[f64]) {
        assert_eq!(
            streaming.len(),
            batch.len(),
            "streaming and batch lengths differ"
        );
        for (i, (&s, &b)) in streaming.iter().zip(batch).enumerate() {
            if s.is_nan() {
                assert!(b.is_nan(), "bar {i}: streaming is NaN but batch is {b}");
            } else {
                approx::assert_relative_eq!(s, b, epsilon = 1e-6);
            }
        }
    }

    proptest! {
        #[test]
        fn test_adx_parity(
            h in prop::collection::vec(1.0..100.0, 1..100),
            l in prop::collection::vec(1.0..100.0, 1..100),
            c in prop::collection::vec(1.0..100.0, 1..100)
        ) {
            let len = h.len().min(l.len()).min(c.len());
            if len < 30 { return Ok(()); }
            let (high, low, close) = hlc(len, &h, &l, &c);
            let period = 14;
            let mut adx = ADX::new(period);
            let streaming: Vec<f64> = (0..len)
                .map(|i| adx.next((high[i], low[i], close[i])))
                .collect();
            let batch = talib_rs::momentum::adx(&high, &low, &close, period)
                .unwrap_or_else(|_| vec![f64::NAN; len]);
            assert_parity(&streaming, &batch);
        }

        #[test]
        fn test_dx_parity(
            h in prop::collection::vec(1.0..100.0, 1..100),
            l in prop::collection::vec(1.0..100.0, 1..100),
            c in prop::collection::vec(1.0..100.0, 1..100)
        ) {
            let len = h.len().min(l.len()).min(c.len());
            if len < 20 { return Ok(()); }
            let (high, low, close) = hlc(len, &h, &l, &c);
            let period = 14;
            let mut dx = DX::new(period);
            let streaming: Vec<f64> = (0..len)
                .map(|i| dx.next((high[i], low[i], close[i])))
                .collect();
            let batch = talib_rs::momentum::dx(&high, &low, &close, period)
                .unwrap_or_else(|_| vec![f64::NAN; len]);
            assert_parity(&streaming, &batch);
        }

        #[test]
        fn test_plus_di_parity(
            h in prop::collection::vec(1.0..100.0, 1..100),
            l in prop::collection::vec(1.0..100.0, 1..100),
            c in prop::collection::vec(1.0..100.0, 1..100)
        ) {
            let len = h.len().min(l.len()).min(c.len());
            if len < 20 { return Ok(()); }
            let (high, low, close) = hlc(len, &h, &l, &c);
            let period = 14;
            let mut pdi = PLUS_DI::new(period);
            let streaming: Vec<f64> = (0..len)
                .map(|i| pdi.next((high[i], low[i], close[i])))
                .collect();
            let batch = talib_rs::momentum::plus_di(&high, &low, &close, period)
                .unwrap_or_else(|_| vec![f64::NAN; len]);
            assert_parity(&streaming, &batch);
        }
    }
}