use crate::error::{Error, Result};
use crate::indicators::adx::directional_movement;
use crate::ohlcv::Candle;
use crate::traits::Indicator;
/// Streaming Directional Movement Index (DX) over a fixed `period`.
///
/// Candles are fed one at a time through [`Indicator::update`]. Positive and
/// negative directional movement (+DM / -DM) and true range (TR) are first
/// summed over `period` samples, then maintained with Wilder's recurrence
/// `s - s / period + x`. From those smoothed values:
///
/// * `+DI = 100 * +DM_s / TR_s` and `-DI = 100 * -DM_s / TR_s` (both `0` when `TR_s == 0`);
/// * `DX = 100 * |+DI - -DI| / (+DI + -DI)` (`0` when the DI sum is `0`).
#[derive(Debug, Clone)]
pub struct Dx {
    /// Number of samples summed during seeding, and the divisor in the Wilder
    /// update. Never zero: `Dx::new` rejects `0`.
    period: usize,
    /// Most recent candle; the reference bar for the next update's
    /// directional movement and true range. `None` before the first candle.
    prev: Option<Candle>,
    /// Running sum of +DM while seeding.
    plus_dm_seed: f64,
    /// Running sum of -DM while seeding.
    minus_dm_seed: f64,
    /// Running sum of true range while seeding.
    tr_seed: f64,
    /// How many samples have been added to the seed sums so far.
    seed_count: usize,
    /// Wilder-smoothed +DM; `None` until `period` seed samples have been collected.
    plus_dm_smooth: Option<f64>,
    /// Wilder-smoothed -DM; `None` until seeding completes.
    minus_dm_smooth: Option<f64>,
    /// Wilder-smoothed true range; `None` until seeding completes. Its presence
    /// is what `is_ready` reports.
    tr_smooth: Option<f64>,
}
impl Dx {
pub fn new(period: usize) -> Result<Self> {
if period == 0 {
return Err(Error::PeriodZero);
}
Ok(Self {
period,
prev: None,
plus_dm_seed: 0.0,
minus_dm_seed: 0.0,
tr_seed: 0.0,
seed_count: 0,
plus_dm_smooth: None,
minus_dm_smooth: None,
tr_smooth: None,
})
}
pub const fn period(&self) -> usize {
self.period
}
}
impl Indicator for Dx {
    type Input = Candle;
    type Output = f64;

    /// Consumes one candle and returns the current DX once warmed up.
    ///
    /// The first candle only becomes the reference bar. Each later candle
    /// contributes one (+DM, -DM, TR) sample. The first `period` samples are
    /// summed, and the sample that completes that sum produces the first
    /// value. After that, every sum advances with `s - s / period + x`.
    ///
    /// Returns `None` while warming up.
    fn update(&mut self, candle: Candle) -> Option<f64> {
        // Swap in the new candle and take the old one in a single step; with
        // no earlier candle there is nothing to measure movement against.
        let prev = self.prev.replace(candle)?;

        let (plus_dm, minus_dm) = directional_movement(&prev, &candle);
        let tr = candle.true_range(Some(prev.close));
        let n = self.period as f64;

        let (plus_v, minus_v, tr_v) =
            match (self.plus_dm_smooth, self.minus_dm_smooth, self.tr_smooth) {
                (Some(p), Some(m), Some(t)) => (
                    wilder_step(p, plus_dm, n),
                    wilder_step(m, minus_dm, n),
                    wilder_step(t, tr, n),
                ),
                _ => {
                    // Still seeding: keep plain sums until `period` samples exist.
                    self.plus_dm_seed += plus_dm;
                    self.minus_dm_seed += minus_dm;
                    self.tr_seed += tr;
                    self.seed_count += 1;
                    if self.seed_count < self.period {
                        return None;
                    }
                    (self.plus_dm_seed, self.minus_dm_seed, self.tr_seed)
                }
            };

        // Both branches end with the new smoothed state being stored.
        self.plus_dm_smooth = Some(plus_v);
        self.minus_dm_smooth = Some(minus_v);
        self.tr_smooth = Some(tr_v);

        Some(dx_from_smoothed(plus_v, minus_v, tr_v))
    }

    /// Returns the indicator to its freshly-constructed state, keeping `period`.
    fn reset(&mut self) {
        *self = Self {
            period: self.period,
            prev: None,
            seed_count: 0,
            plus_dm_seed: 0.0,
            minus_dm_seed: 0.0,
            tr_seed: 0.0,
            plus_dm_smooth: None,
            minus_dm_smooth: None,
            tr_smooth: None,
        };
    }

    /// Index of the first `Some` output: the initial reference candle plus
    /// `period - 1` seed samples precede it.
    fn warmup_period(&self) -> usize {
        self.period
    }

    /// `true` once seeding has completed and smoothed values exist.
    fn is_ready(&self) -> bool {
        self.tr_smooth.is_some()
    }

    fn name(&self) -> &'static str {
        "DX"
    }
}

/// One step of Wilder's smoothing on a running sum: `sum - sum / n + value`.
fn wilder_step(sum: f64, value: f64, n: f64) -> f64 {
    sum - sum / n + value
}

/// Turns smoothed +DM, -DM and true range into a DX reading.
///
/// A zero true range yields zero DIs, and a zero DI sum yields a DX of `0.0`,
/// so neither division can produce NaN from a zero denominator.
fn dx_from_smoothed(plus: f64, minus: f64, tr: f64) -> f64 {
    let (plus_di, minus_di) = if tr == 0.0 {
        (0.0, 0.0)
    } else {
        (100.0 * plus / tr, 100.0 * minus / tr)
    };
    let di_sum = plus_di + minus_di;
    if di_sum == 0.0 {
        0.0
    } else {
        100.0 * (plus_di - minus_di).abs() / di_sum
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Shorthand candle from high `h`, low `l` and close `cl`. `cl` is also
    /// passed as the first `Candle::new` argument, and the last two arguments
    /// are fixed at `1.0` and `0`. NOTE(review): assumes the argument order is
    /// open, high, low, close, volume, timestamp — confirm against `Candle::new`.
    fn c(h: f64, l: f64, cl: f64) -> Candle {
        Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
    }

    /// `count` bars whose base level climbs by 2.0 per bar.
    fn rising(count: i32) -> Vec<Candle> {
        (0..count)
            .map(f64::from)
            .map(|step| {
                let base = 100.0 + step * 2.0;
                c(base + 1.0, base - 0.5, base + 0.5)
            })
            .collect()
    }

    #[test]
    fn rejects_zero_period() {
        let result = Dx::new(0);
        assert!(matches!(result, Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_report_config() {
        let indicator = Dx::new(7).unwrap();
        assert_eq!(indicator.period(), 7);
        assert_eq!(indicator.name(), "DX");
        assert_eq!(indicator.warmup_period(), 7);
        assert!(!indicator.is_ready());
    }

    #[test]
    fn strong_trend_drives_dx_high() {
        let candles = rising(12);
        let mut dx = Dx::new(3).unwrap();
        let out: Vec<Option<f64>> = dx.batch(&candles);
        assert!(out[0].is_none());
        assert!(out[3].is_some());
        let last = out.into_iter().flatten().next_back().unwrap();
        assert!(last > 50.0 && last <= 100.0, "unexpected DX {last}");
        assert!(dx.is_ready());
    }

    #[test]
    fn flat_market_returns_zero() {
        let candles = vec![c(50.0, 50.0, 50.0); 6];
        let mut dx = Dx::new(3).unwrap();
        let outputs = dx.batch(&candles);
        let last = outputs.into_iter().flatten().next_back().unwrap();
        assert_relative_eq!(last, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn balanced_directional_movement_is_low() {
        // Base alternates 100.0 / 101.0 on even / odd bars.
        let candles: Vec<Candle> = (0..30)
            .map(|i| {
                let base = 100.0 + f64::from(i % 2);
                c(base + 1.0, base - 1.0, base)
            })
            .collect();
        let mut dx = Dx::new(5).unwrap();
        let outputs = dx.batch(&candles);
        let last = outputs.into_iter().flatten().next_back().unwrap();
        assert!((0.0..=100.0).contains(&last));
    }

    #[test]
    fn reset_restores_initial_state() {
        let candles = rising(6);
        let mut dx = Dx::new(3).unwrap();
        let _ = dx.batch(&candles);
        assert!(dx.is_ready());
        dx.reset();
        assert!(!dx.is_ready());
        assert!(dx.update(candles[0]).is_none());
    }
}