use crate::traits::Next;
/// Shared running state for the directional-movement indicators in this file.
///
/// Consumes one `(high, low, close)` bar at a time and maintains smoothed sums
/// of true range and plus/minus directional movement, from which [`DmiCore::step`]
/// derives `(+DI, -DI, DX)`.
#[derive(Debug, Clone)]
struct DmiCore {
    /// Smoothing period; `step` never yields a value when this is 0.
    timeperiod: usize,
    /// `timeperiod` pre-converted to `f64` for the smoothing arithmetic.
    period_f: f64,
    /// High of the previous bar; `None` until the first bar has been seen.
    prev_high: Option<f64>,
    /// Low of the previous bar; `None` until the first bar has been seen.
    prev_low: Option<f64>,
    /// Close of the previous bar; `None` until the first bar has been seen.
    prev_close: Option<f64>,
    /// Number of bars consumed so far (not advanced while `timeperiod` is 0).
    bar_index: usize,
    /// Running (then smoothed) sum of true range.
    sum_tr: f64,
    /// Running (then smoothed) sum of plus directional movement.
    sum_pdm: f64,
    /// Running (then smoothed) sum of minus directional movement.
    sum_mdm: f64,
    /// Becomes `true` once the plain accumulation phase is over.
    seeded: bool,
}

impl DmiCore {
    /// Creates an empty core for the given smoothing period.
    fn new(timeperiod: usize) -> Self {
        Self {
            timeperiod,
            period_f: timeperiod as f64,
            prev_high: None,
            prev_low: None,
            prev_close: None,
            bar_index: 0,
            sum_tr: 0.0,
            sum_pdm: 0.0,
            sum_mdm: 0.0,
            seeded: false,
        }
    }

    /// Computes `(true_range, plus_dm, minus_dm)` for the bar `(high, low)`
    /// relative to the stored previous bar.
    ///
    /// Returns `None` while no previous bar has been recorded, so callers never
    /// have to unwrap the individual `prev_*` fields.
    fn dm_components(&self, high: f64, low: f64) -> Option<(f64, f64, f64)> {
        let (prev_high, prev_low, prev_close) = (self.prev_high?, self.prev_low?, self.prev_close?);

        // True range: the widest of the bar's own range and its distance from
        // the previous close on either side.
        let tr = (high - low)
            .max((high - prev_close).abs())
            .max((low - prev_close).abs());

        // Only the larger of the two moves counts, and only if it is positive.
        let up = high - prev_high;
        let down = prev_low - low;
        let pdm = if up > down && up > 0.0 { up } else { 0.0 };
        let mdm = if down > up && down > 0.0 { down } else { 0.0 };
        Some((tr, pdm, mdm))
    }

    /// Feeds one bar and returns `(+DI, -DI, DX)` once enough bars were seen.
    ///
    /// Returns `None` when:
    /// * `timeperiod` is 0,
    /// * the bar is the very first one (it only seeds the previous-bar state),
    /// * fewer than `timeperiod` bars followed the first one (accumulation),
    /// * the smoothed true range is not positive (the ratios are undefined).
    fn step(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64)> {
        let period = self.timeperiod;
        if period < 1 {
            return None;
        }

        // Measure against the previous bar *before* overwriting it.
        let components = self.dm_components(high, low);
        self.prev_high = Some(high);
        self.prev_low = Some(low);
        self.prev_close = Some(close);

        let Some((tr, pdm, mdm)) = components else {
            // First bar: nothing to compare with yet.
            self.bar_index = 1;
            return None;
        };

        let i = self.bar_index;
        self.bar_index += 1;

        if !self.seeded {
            if i < period {
                // Bars 1..period-1 are simply summed.
                self.sum_tr += tr;
                self.sum_pdm += pdm;
                self.sum_mdm += mdm;
                return None;
            }
            self.seeded = true;
        }

        // Smoothing step: drop one period-th of the running sum, add the new value.
        self.sum_tr = self.sum_tr - self.sum_tr / self.period_f + tr;
        self.sum_pdm = self.sum_pdm - self.sum_pdm / self.period_f + pdm;
        self.sum_mdm = self.sum_mdm - self.sum_mdm / self.period_f + mdm;

        if self.sum_tr <= 0.0 {
            return None;
        }

        let pdi = 100.0 * self.sum_pdm / self.sum_tr;
        let mdi = 100.0 * self.sum_mdm / self.sum_tr;
        let sum_di = pdi + mdi;
        let dx = if sum_di > 0.0 {
            100.0 * (pdi - mdi).abs() / sum_di
        } else {
            0.0
        };
        Some((pdi, mdi, dx))
    }
}
/// Streaming indicator that reports the first (`pdi`) component produced by
/// `DmiCore::step` for each `(high, low, close)` bar.
#[derive(Debug, Clone)]
#[allow(non_camel_case_types)]
pub struct PLUS_DI {
    /// Period this indicator was constructed with.
    pub timeperiod: usize,
    /// Running directional-movement state.
    core: DmiCore,
}

impl PLUS_DI {
    /// Builds an indicator whose internal state uses `timeperiod`.
    pub fn new(timeperiod: usize) -> Self {
        let core = DmiCore::new(timeperiod);
        Self { timeperiod, core }
    }
}

impl Next<(f64, f64, f64)> for PLUS_DI {
    type Output = f64;

    /// Consumes one bar; yields NaN whenever the core has no value for it.
    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
        self.core
            .step(high, low, close)
            .map_or(f64::NAN, |(plus_di, _, _)| plus_di)
    }
}
/// Streaming indicator that reports the second (`mdi`) component produced by
/// `DmiCore::step` for each `(high, low, close)` bar.
#[derive(Debug, Clone)]
#[allow(non_camel_case_types)]
pub struct MINUS_DI {
    /// Period this indicator was constructed with.
    pub timeperiod: usize,
    /// Running directional-movement state.
    core: DmiCore,
}

impl MINUS_DI {
    /// Builds an indicator whose internal state uses `timeperiod`.
    pub fn new(timeperiod: usize) -> Self {
        let core = DmiCore::new(timeperiod);
        Self { timeperiod, core }
    }
}

impl Next<(f64, f64, f64)> for MINUS_DI {
    type Output = f64;

    /// Consumes one bar; yields NaN whenever the core has no value for it.
    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
        let Some((_, minus_di, _)) = self.core.step(high, low, close) else {
            return f64::NAN;
        };
        minus_di
    }
}
/// Streaming indicator that reports the third (`dx`) component produced by
/// `DmiCore::step` for each `(high, low, close)` bar.
#[derive(Debug, Clone)]
#[allow(non_camel_case_types)]
pub struct DX {
    /// Period this indicator was constructed with.
    pub timeperiod: usize,
    /// Running directional-movement state.
    core: DmiCore,
}

impl DX {
    /// Builds an indicator whose internal state uses `timeperiod`.
    pub fn new(timeperiod: usize) -> Self {
        let core = DmiCore::new(timeperiod);
        Self { timeperiod, core }
    }
}

impl Next<(f64, f64, f64)> for DX {
    type Output = f64;

    /// Consumes one bar; yields NaN whenever the core has no value for it.
    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
        let directional_index = self.core.step(high, low, close).map(|(_, _, dx)| dx);
        directional_index.unwrap_or(f64::NAN)
    }
}
/// Streaming average of the `dx` values produced by `DmiCore::step`.
///
/// The first value is the plain mean of the DX values from bars
/// `timeperiod ..= 2 * timeperiod - 1` (0-based); every later value blends the
/// previous average with the new DX using weights `(timeperiod - 1) : 1`.
#[derive(Debug, Clone)]
// The upper-case indicator name is the public identifier callers use.
#[allow(non_camel_case_types)]
pub struct ADX {
    /// Period this indicator was constructed with; values below 2 yield NaN only.
    pub timeperiod: usize,
    /// Running directional-movement state that supplies DX.
    core: DmiCore,
    /// Sum of the DX values collected while waiting for the seed bar.
    dx_sum: f64,
    /// Most recent average; meaningful only once `adx_ready` is set.
    adx: f64,
    /// Set when the seed bar has been processed.
    adx_ready: bool,
}

impl ADX {
    /// Builds an indicator whose internal state uses `timeperiod`.
    pub fn new(timeperiod: usize) -> Self {
        Self {
            timeperiod,
            core: DmiCore::new(timeperiod),
            dx_sum: 0.0,
            adx: 0.0,
            adx_ready: false,
        }
    }
}

impl Next<(f64, f64, f64)> for ADX {
    type Output = f64;

    /// Consumes one bar and returns the current average, or NaN when
    /// * `timeperiod < 2`,
    /// * the bar lies before the seed bar `2 * timeperiod - 1`, or
    /// * the core has no DX for this bar.
    ///
    /// A bar without DX never blocks later output: the seed is still taken on
    /// the seed bar (from the DX values collected so far, divided by
    /// `timeperiod`), and after seeding such a bar leaves the average untouched.
    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
        let period = self.timeperiod;
        if period < 2 {
            return f64::NAN;
        }

        let dx = self.core.step(high, low, close).map(|(_, _, dx)| dx);
        // 0-based index of the bar that was just consumed.
        let bar = self.core.bar_index.saturating_sub(1);
        if bar < period {
            // The core itself is still warming up.
            return f64::NAN;
        }

        let period_f = period as f64;
        let adx_start = 2 * period - 1;

        if bar <= adx_start {
            if let Some(dx) = dx {
                self.dx_sum += dx;
            }
            if bar < adx_start {
                return f64::NAN;
            }
            // Seed bar: always establish the average so one undefined DX here
            // cannot leave the indicator stuck on NaN forever.
            self.adx = self.dx_sum / period_f;
            self.adx_ready = true;
            return if dx.is_some() { self.adx } else { f64::NAN };
        }

        match dx {
            Some(dx) if self.adx_ready => {
                self.adx = (self.adx * (period_f - 1.0) + dx) / period_f;
                self.adx
            }
            _ => f64::NAN,
        }
    }
}
/// Streaming mean of the current `ADX` value and the `ADX` value from
/// `timeperiod - 1` bars earlier.
#[derive(Debug, Clone)]
// The upper-case indicator name is the public identifier callers use.
#[allow(non_camel_case_types)]
pub struct ADXR {
    /// Period this indicator was constructed with.
    pub timeperiod: usize,
    /// Underlying ADX stream.
    adx: ADX,
    /// Sliding window holding at most the last `timeperiod` ADX outputs, so
    /// memory stays bounded no matter how many bars are streamed.
    adx_history: std::collections::VecDeque<f64>,
    /// Number of bars consumed so far.
    bars_seen: usize,
}

impl ADXR {
    /// Builds an indicator whose underlying ADX uses `timeperiod`.
    pub fn new(timeperiod: usize) -> Self {
        Self {
            timeperiod,
            adx: ADX::new(timeperiod),
            adx_history: std::collections::VecDeque::new(),
            bars_seen: 0,
        }
    }
}

impl Next<(f64, f64, f64)> for ADXR {
    type Output = f64;

    /// Consumes one bar and returns `(adx_now + adx_lagged) / 2`, or NaN when
    /// * `timeperiod` is 0,
    /// * fewer than `3 * timeperiod - 2` bars preceded this one, or
    /// * either of the two ADX values involved is NaN.
    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
        let period = self.timeperiod;
        let adx_val = self.adx.next((high, low, close));
        if period < 1 {
            // Without this guard `3 * period - 2` below would underflow.
            return f64::NAN;
        }

        // 0-based index of the bar that was just consumed.
        let bar = self.bars_seen;
        self.bars_seen = self.bars_seen.saturating_add(1);

        // Keep only what the lagged lookup can still reach.
        self.adx_history.push_back(adx_val);
        while self.adx_history.len() > period {
            self.adx_history.pop_front();
        }

        let adxr_lookback = period.saturating_mul(3).saturating_sub(2);
        if bar < adxr_lookback || adx_val.is_nan() {
            return f64::NAN;
        }

        // With a full window the front element is the value recorded
        // `period - 1` bars ago.
        match self.adx_history.front() {
            Some(&past) if self.adx_history.len() == period && !past.is_nan() => {
                (adx_val + past) / 2.0
            }
            _ => f64::NAN,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Turns three raw series into consistent bars of length `len`: per index
    /// the high is the largest and the low the smallest of the three raw
    /// values, while the close is taken from `c` unchanged.
    fn hlc(len: usize, h: &[f64], l: &[f64], c: &[f64]) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let mut high = Vec::with_capacity(len);
        let mut low = Vec::with_capacity(len);
        let mut close = Vec::with_capacity(len);
        let raw_bars = h[..len].iter().zip(&l[..len]).zip(&c[..len]);
        for ((&raw_h, &raw_l), &raw_c) in raw_bars {
            high.push(raw_h.max(raw_l).max(raw_c));
            low.push(raw_h.min(raw_l).min(raw_c));
            close.push(raw_c);
        }
        (high, low, close)
    }

    /// Feeds every bar through `indicator` and collects the per-bar outputs.
    fn stream<I>(mut indicator: I, high: &[f64], low: &[f64], close: &[f64]) -> Vec<f64>
    where
        I: Next<(f64, f64, f64), Output = f64>,
    {
        high.iter()
            .zip(low)
            .zip(close)
            .map(|((&bar_high, &bar_low), &bar_close)| indicator.next((bar_high, bar_low, bar_close)))
            .collect()
    }

    /// Wherever the streaming output is NaN the batch output must be NaN too;
    /// everywhere else the two must agree within `1e-6`.
    fn assert_parity(streaming: &[f64], batch: &[f64]) {
        for (s, b) in streaming.iter().zip(batch.iter()) {
            if s.is_nan() {
                assert!(b.is_nan());
            } else {
                approx::assert_relative_eq!(s, b, epsilon = 1e-6);
            }
        }
    }

    proptest! {
        #[test]
        fn test_adx_parity(
            h in prop::collection::vec(1.0..100.0, 1..100),
            l in prop::collection::vec(1.0..100.0, 1..100),
            c in prop::collection::vec(1.0..100.0, 1..100)
        ) {
            let len = h.len().min(l.len()).min(c.len());
            // Too short for the indicator to produce anything worth comparing.
            if len < 30 { return Ok(()); }
            let (high, low, close) = hlc(len, &h, &l, &c);
            let period = 14;
            let streaming = stream(ADX::new(period), &high, &low, &close);
            let batch = talib_rs::momentum::adx(&high, &low, &close, period)
                .unwrap_or_else(|_| vec![f64::NAN; len]);
            assert_parity(&streaming, &batch);
        }

        #[test]
        fn test_dx_parity(
            h in prop::collection::vec(1.0..100.0, 1..100),
            l in prop::collection::vec(1.0..100.0, 1..100),
            c in prop::collection::vec(1.0..100.0, 1..100)
        ) {
            let len = h.len().min(l.len()).min(c.len());
            if len < 20 { return Ok(()); }
            let (high, low, close) = hlc(len, &h, &l, &c);
            let period = 14;
            let streaming = stream(DX::new(period), &high, &low, &close);
            let batch = talib_rs::momentum::dx(&high, &low, &close, period)
                .unwrap_or_else(|_| vec![f64::NAN; len]);
            assert_parity(&streaming, &batch);
        }

        #[test]
        fn test_plus_di_parity(
            h in prop::collection::vec(1.0..100.0, 1..100),
            l in prop::collection::vec(1.0..100.0, 1..100),
            c in prop::collection::vec(1.0..100.0, 1..100)
        ) {
            let len = h.len().min(l.len()).min(c.len());
            if len < 20 { return Ok(()); }
            let (high, low, close) = hlc(len, &h, &l, &c);
            let period = 14;
            let streaming = stream(PLUS_DI::new(period), &high, &low, &close);
            let batch = talib_rs::momentum::plus_di(&high, &low, &close, period)
                .unwrap_or_else(|_| vec![f64::NAN; len]);
            assert_parity(&streaming, &batch);
        }
    }
}