1use crate::traits::Next;
4
/// Shared state machine behind the +DI, -DI, DX and ADX indicators.
///
/// Remembers the previous bar and keeps Wilder-smoothed sums of true range (TR),
/// plus directional movement (+DM) and minus directional movement (-DM) over a
/// `timeperiod`-bar window.
#[derive(Debug, Clone)]
struct DmiCore {
    /// Smoothing window length in bars.
    timeperiod: usize,
    /// `timeperiod` cached as `f64` for the smoothing divisions.
    period_f: f64,
    /// High of the previous bar (`None` until the first bar has been seen).
    prev_high: Option<f64>,
    /// Low of the previous bar (`None` until the first bar has been seen).
    prev_low: Option<f64>,
    /// Close of the previous bar (`None` until the first bar has been seen).
    prev_close: Option<f64>,
    /// Number of bars processed so far; stays 0 when `timeperiod` is 0.
    bar_index: usize,
    /// Smoothed true-range sum.
    sum_tr: f64,
    /// Smoothed +DM sum.
    sum_pdm: f64,
    /// Smoothed -DM sum.
    sum_mdm: f64,
    /// Set once the raw seeding sums are complete and Wilder smoothing has begun.
    seeded: bool,
}

impl DmiCore {
    /// Creates an empty core for a `timeperiod`-bar window.
    fn new(timeperiod: usize) -> Self {
        Self {
            timeperiod,
            period_f: timeperiod as f64,
            prev_high: None,
            prev_low: None,
            prev_close: None,
            bar_index: 0,
            sum_tr: 0.0,
            sum_pdm: 0.0,
            sum_mdm: 0.0,
            seeded: false,
        }
    }

    /// Returns `(tr, +dm, -dm)` for the bar `high`/`low`, relative to the previous
    /// bar's `prev_high`, `prev_low` and `prev_close`.
    #[inline]
    fn dm_components(
        prev_high: f64,
        prev_low: f64,
        prev_close: f64,
        high: f64,
        low: f64,
    ) -> (f64, f64, f64) {
        let tr = (high - low)
            .max((high - prev_close).abs())
            .max((low - prev_close).abs());
        let up = high - prev_high;
        let down = prev_low - low;
        // Only the larger, strictly positive move counts as directional movement.
        let pdm = if up > down && up > 0.0 { up } else { 0.0 };
        let mdm = if down > up && down > 0.0 { down } else { 0.0 };
        (tr, pdm, mdm)
    }

    /// Consumes one bar and returns `(+DI, -DI, DX)` once the sums are seeded.
    ///
    /// Returns `None` for every bar when `timeperiod` is 0. Otherwise it returns `None`
    /// for the first `timeperiod` bars:
    /// - bar 0 only records the previous-bar values;
    /// - bars `1..timeperiod` accumulate raw TR/+DM/-DM sums.
    ///
    /// From bar `timeperiod` on, each sum is Wilder-smoothed (`sum - sum / n + x`) and
    /// a value is returned for every bar. When the smoothed TR is zero, all three
    /// outputs are 0. When +DI + -DI is zero, DX is 0.
    fn step(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64)> {
        let period = self.timeperiod;
        if period < 1 {
            return None;
        }

        let (Some(prev_high), Some(prev_low), Some(prev_close)) =
            (self.prev_high, self.prev_low, self.prev_close)
        else {
            // First bar: nothing to diff against yet, just remember it.
            self.prev_high = Some(high);
            self.prev_low = Some(low);
            self.prev_close = Some(close);
            self.bar_index = 1;
            return None;
        };

        let (tr, pdm, mdm) = Self::dm_components(prev_high, prev_low, prev_close, high, low);
        self.prev_high = Some(high);
        self.prev_low = Some(low);
        self.prev_close = Some(close);
        let i = self.bar_index;
        self.bar_index += 1;

        if !self.seeded {
            if i < period {
                self.sum_tr += tr;
                self.sum_pdm += pdm;
                self.sum_mdm += mdm;
                return None;
            }
            self.seeded = true;
        }
        self.sum_tr = self.sum_tr - self.sum_tr / self.period_f + tr;
        self.sum_pdm = self.sum_pdm - self.sum_pdm / self.period_f + pdm;
        self.sum_mdm = self.sum_mdm - self.sum_mdm / self.period_f + mdm;

        if self.sum_tr <= 0.0 {
            // No range at all (e.g. perfectly flat prices). Report 0, as TA-Lib does,
            // instead of `None`: callers map `None` to NaN, which would make a
            // warmed-up indicator look un-initialised again.
            return Some((0.0, 0.0, 0.0));
        }
        let pdi = 100.0 * self.sum_pdm / self.sum_tr;
        let mdi = 100.0 * self.sum_mdm / self.sum_tr;
        let sum_di = pdi + mdi;
        let dx = if sum_di > 0.0 {
            100.0 * (pdi - mdi).abs() / sum_di
        } else {
            0.0
        };
        Some((pdi, mdi, dx))
    }
}
100}
101
102#[derive(Debug, Clone)]
104#[allow(non_camel_case_types)]
105pub struct PLUS_DI {
106 pub timeperiod: usize,
107 core: DmiCore,
108}
109
110impl PLUS_DI {
111 pub fn new(timeperiod: usize) -> Self {
112 Self {
113 timeperiod,
114 core: DmiCore::new(timeperiod),
115 }
116 }
117}
118
119impl Next<(f64, f64, f64)> for PLUS_DI {
120 type Output = f64;
121
122 fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
123 match self.core.step(high, low, close) {
124 Some((pdi, _, _)) => pdi,
125 None => f64::NAN,
126 }
127 }
128}
129
130#[derive(Debug, Clone)]
132#[allow(non_camel_case_types)]
133pub struct MINUS_DI {
134 pub timeperiod: usize,
135 core: DmiCore,
136}
137
138impl MINUS_DI {
139 pub fn new(timeperiod: usize) -> Self {
140 Self {
141 timeperiod,
142 core: DmiCore::new(timeperiod),
143 }
144 }
145}
146
147impl Next<(f64, f64, f64)> for MINUS_DI {
148 type Output = f64;
149
150 fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
151 match self.core.step(high, low, close) {
152 Some((_, mdi, _)) => mdi,
153 None => f64::NAN,
154 }
155 }
156}
157
158#[derive(Debug, Clone)]
160#[allow(non_camel_case_types)]
161pub struct DX {
162 pub timeperiod: usize,
163 core: DmiCore,
164}
165
166impl DX {
167 pub fn new(timeperiod: usize) -> Self {
168 Self {
169 timeperiod,
170 core: DmiCore::new(timeperiod),
171 }
172 }
173}
174
175impl Next<(f64, f64, f64)> for DX {
176 type Output = f64;
177
178 fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
179 match self.core.step(high, low, close) {
180 Some((_, _, dx)) => dx,
181 None => f64::NAN,
182 }
183 }
184}
185
186#[derive(Debug, Clone)]
188#[allow(non_camel_case_types)]
189pub struct ADX {
190 pub timeperiod: usize,
191 core: DmiCore,
192 dx_values: Vec<f64>,
193 adx: f64,
194 adx_ready: bool,
195}
196
197impl ADX {
198 pub fn new(timeperiod: usize) -> Self {
199 Self {
200 timeperiod,
201 core: DmiCore::new(timeperiod),
202 dx_values: Vec::new(),
203 adx: 0.0,
204 adx_ready: false,
205 }
206 }
207}
208
209impl Next<(f64, f64, f64)> for ADX {
210 type Output = f64;
211
212 fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
213 let period = self.timeperiod;
214 if period < 2 {
215 return f64::NAN;
216 }
217
218 let Some((_, _, dx)) = self.core.step(high, low, close) else {
219 return f64::NAN;
220 };
221
222 let adx_start = 2 * period - 1;
223 let bar = self.core.bar_index.saturating_sub(1);
224
225 if bar < period {
226 return f64::NAN;
227 }
228
229 if bar < adx_start {
230 self.dx_values.push(dx);
231 return f64::NAN;
232 }
233
234 if bar == adx_start {
235 self.dx_values.push(dx);
236 let seed: f64 = self.dx_values.iter().sum::<f64>() / period as f64;
237 self.adx = seed;
238 self.adx_ready = true;
239 return seed;
240 }
241
242 if self.adx_ready {
243 self.adx = (self.adx * (period as f64 - 1.0) + dx) / period as f64;
244 return self.adx;
245 }
246
247 f64::NAN
248 }
249}
250
251#[derive(Debug, Clone)]
253#[allow(non_camel_case_types)]
254pub struct ADXR {
255 pub timeperiod: usize,
256 adx: ADX,
257 adx_history: Vec<f64>,
258}
259
260impl ADXR {
261 pub fn new(timeperiod: usize) -> Self {
262 Self {
263 timeperiod,
264 adx: ADX::new(timeperiod),
265 adx_history: Vec::new(),
266 }
267 }
268}
269
270impl Next<(f64, f64, f64)> for ADXR {
271 type Output = f64;
272
273 fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
274 let period = self.timeperiod;
275 let adx_val = self.adx.next((high, low, close));
276 self.adx_history.push(adx_val);
277
278 let adxr_lookback = 3 * period - 2;
279 let bar = self.adx_history.len().saturating_sub(1);
280 if bar < adxr_lookback {
281 return f64::NAN;
282 }
283 if adx_val.is_nan() {
284 return f64::NAN;
285 }
286 let past_idx = bar + 1 - period;
287 let past = self.adx_history[past_idx];
288 if past.is_nan() {
289 return f64::NAN;
290 }
291 (adx_val + past) / 2.0
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds consistent bars from three raw series: each bar's high is the largest
    /// raw value, its low the smallest, and its close the raw close.
    fn hlc(len: usize, h: &[f64], l: &[f64], c: &[f64]) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let mut high = Vec::with_capacity(len);
        let mut low = Vec::with_capacity(len);
        let mut close = Vec::with_capacity(len);
        for ((&raw_h, &raw_l), &raw_c) in h.iter().zip(l).zip(c).take(len) {
            high.push(raw_h.max(raw_l).max(raw_c));
            low.push(raw_h.min(raw_l).min(raw_c));
            close.push(raw_c);
        }
        (high, low, close)
    }

    /// Feeds every bar through a streaming indicator and collects its outputs.
    fn stream<I>(mut indicator: I, high: &[f64], low: &[f64], close: &[f64]) -> Vec<f64>
    where
        I: Next<(f64, f64, f64), Output = f64>,
    {
        high.iter()
            .zip(low)
            .zip(close)
            .map(|((&h, &l), &c)| indicator.next((h, l, c)))
            .collect()
    }

    /// Asserts streaming/batch agreement: the batch value is NaN wherever the
    /// streaming value is, and otherwise matches within `1e-6`.
    fn assert_parity(streaming: &[f64], batch: &[f64]) {
        for (s, b) in streaming.iter().zip(batch) {
            if s.is_nan() {
                assert!(b.is_nan());
            } else {
                approx::assert_relative_eq!(s, b, epsilon = 1e-6);
            }
        }
    }

    proptest! {
        #[test]
        fn test_adx_parity(
            h in prop::collection::vec(1.0..100.0, 1..100),
            l in prop::collection::vec(1.0..100.0, 1..100),
            c in prop::collection::vec(1.0..100.0, 1..100)
        ) {
            let len = h.len().min(l.len()).min(c.len());
            if len < 30 { return Ok(()); }
            let (high, low, close) = hlc(len, &h, &l, &c);
            let period = 14;
            let streaming = stream(ADX::new(period), &high, &low, &close);
            let batch = talib_rs::momentum::adx(&high, &low, &close, period)
                .unwrap_or_else(|_| vec![f64::NAN; len]);
            assert_parity(&streaming, &batch);
        }

        #[test]
        fn test_dx_parity(
            h in prop::collection::vec(1.0..100.0, 1..100),
            l in prop::collection::vec(1.0..100.0, 1..100),
            c in prop::collection::vec(1.0..100.0, 1..100)
        ) {
            let len = h.len().min(l.len()).min(c.len());
            if len < 20 { return Ok(()); }
            let (high, low, close) = hlc(len, &h, &l, &c);
            let period = 14;
            let streaming = stream(DX::new(period), &high, &low, &close);
            let batch = talib_rs::momentum::dx(&high, &low, &close, period)
                .unwrap_or_else(|_| vec![f64::NAN; len]);
            assert_parity(&streaming, &batch);
        }

        #[test]
        fn test_plus_di_parity(
            h in prop::collection::vec(1.0..100.0, 1..100),
            l in prop::collection::vec(1.0..100.0, 1..100),
            c in prop::collection::vec(1.0..100.0, 1..100)
        ) {
            let len = h.len().min(l.len()).min(c.len());
            if len < 20 { return Ok(()); }
            let (high, low, close) = hlc(len, &h, &l, &c);
            let period = 14;
            let streaming = stream(PLUS_DI::new(period), &high, &low, &close);
            let batch = talib_rs::momentum::plus_di(&high, &low, &close, period)
                .unwrap_or_else(|_| vec![f64::NAN; len]);
            assert_parity(&streaming, &batch);
        }
    }
}
382}