use super::{IndicatorError, Result};
/// Computes the Average Directional Index (ADX) with Wilder-style smoothing.
///
/// For every bar `i >= 1` the true range (TR) and the directional movements
/// (+DM / -DM) are derived from `highs`, `lows` and the previous close.
/// The first `period` values are averaged to seed the smoothed series; after
/// that each series is updated as `(prev * (period - 1) + current) / period`.
/// The directional index (DX) is computed from the smoothed values, and the
/// ADX is the same recursive average applied to DX.
///
/// # Returns
///
/// A vector with the same length as the inputs. Entries before index
/// `2 * period - 1` are `None` (warm-up); every later entry is `Some(adx)`.
///
/// # Errors
///
/// * [`IndicatorError::InvalidPeriod`] if `period` is `0`, or if the three
///   input slices do not all have the same length.
/// * [`IndicatorError::InsufficientData`] if fewer than `2 * period` bars are
///   supplied (including the case where `2 * period` does not fit in `usize`).
pub fn adx(
    highs: &[f64],
    lows: &[f64],
    closes: &[f64],
    period: usize,
) -> Result<Vec<Option<f64>>> {
    if period == 0 {
        return Err(IndicatorError::InvalidPeriod(
            "Period must be greater than 0".to_string(),
        ));
    }
    let len = highs.len();
    if lows.len() != len || closes.len() != len {
        return Err(IndicatorError::InvalidPeriod(
            "Data lengths must match".to_string(),
        ));
    }

    // `2 * period` may overflow for absurdly large periods. No slice can be
    // that long, so report it as insufficient data instead of panicking
    // (debug) or wrapping around and indexing out of bounds (release).
    let first_adx_idx = match period.checked_mul(2) {
        Some(need) if len >= need => need - 1,
        _ => {
            return Err(IndicatorError::InsufficientData {
                need: period.saturating_mul(2),
                got: len,
            });
        }
    };

    let period_f = period as f64;
    let period_m1 = (period - 1) as f64;

    // True range and (+DM, -DM) for bar `i`; requires `i >= 1`.
    let tr_dm = |i: usize| -> (f64, f64, f64) {
        let high_low = highs[i] - lows[i];
        let high_close = (highs[i] - closes[i - 1]).abs();
        let low_close = (lows[i] - closes[i - 1]).abs();
        let tr = high_low.max(high_close).max(low_close);
        let up = highs[i] - highs[i - 1];
        let dn = lows[i - 1] - lows[i];
        // Only the dominant, strictly positive move counts; ties count for neither side.
        let plus = if up > dn && up > 0.0 { up } else { 0.0 };
        let minus = if dn > up && dn > 0.0 { dn } else { 0.0 };
        (tr, plus, minus)
    };

    // DX from smoothed TR / +DM / -DM; zero denominators yield 0 rather than NaN.
    let dx_at = |tr: f64, plus: f64, minus: f64| -> f64 {
        let plus_di = if tr != 0.0 { 100.0 * plus / tr } else { 0.0 };
        let minus_di = if tr != 0.0 { 100.0 * minus / tr } else { 0.0 };
        let di_sum = plus_di + minus_di;
        if di_sum != 0.0 {
            100.0 * (plus_di - minus_di).abs() / di_sum
        } else {
            0.0
        }
    };

    // One step of the recursive smoothing used for every series.
    let smooth = |prev: f64, current: f64| -> f64 { (prev * period_m1 + current) / period_f };

    // Seed the smoothed series with the simple average of bars 1..=period.
    let mut s_tr = 0.0_f64;
    let mut s_plus = 0.0_f64;
    let mut s_minus = 0.0_f64;
    for i in 1..=period {
        let (tr, plus, minus) = tr_dm(i);
        s_tr += tr;
        s_plus += plus;
        s_minus += minus;
    }
    s_tr /= period_f;
    s_plus /= period_f;
    s_minus /= period_f;

    // Accumulate the first `period` DX values (bars period..=2*period-1).
    let mut dx_sum = dx_at(s_tr, s_plus, s_minus);
    for i in (period + 1)..=first_adx_idx {
        let (tr, plus, minus) = tr_dm(i);
        s_tr = smooth(s_tr, tr);
        s_plus = smooth(s_plus, plus);
        s_minus = smooth(s_minus, minus);
        dx_sum += dx_at(s_tr, s_plus, s_minus);
    }

    // The first ADX is the plain average of those DX values; later ones are smoothed.
    let mut result = vec![None; len];
    let mut adx_value = dx_sum / period_f;
    result[first_adx_idx] = Some(adx_value);
    for (i, slot) in result.iter_mut().enumerate().skip(first_adx_idx + 1) {
        let (tr, plus, minus) = tr_dm(i);
        s_tr = smooth(s_tr, tr);
        s_plus = smooth(s_plus, plus);
        s_minus = smooth(s_minus, minus);
        adx_value = smooth(adx_value, dx_at(s_tr, s_plus, s_minus));
        *slot = Some(adx_value);
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A perfectly flat market has no directional movement, so every ADX
    /// value after the `2 * period - 1` warm-up bars must be exactly zero.
    #[test]
    fn test_adx() {
        let highs = vec![10.0; 30];
        let lows = vec![8.0; 30];
        let closes = vec![9.0; 30];
        let result = adx(&highs, &lows, &closes, 14).unwrap();
        assert_eq!(result.len(), 30);
        // Indices 0..=26 are warm-up; the first value appears at 2 * 14 - 1 = 27.
        assert!(result[..27].iter().all(Option::is_none));
        assert!(result[27..].iter().all(|v| *v == Some(0.0)));
    }

    /// Invalid arguments are rejected with the matching error variant.
    #[test]
    fn test_adx_rejects_bad_input() {
        let data = vec![1.0; 27];
        assert!(matches!(
            adx(&data, &data, &data, 0),
            Err(IndicatorError::InvalidPeriod(_))
        ));
        assert!(matches!(
            adx(&data, &data[..26], &data, 5),
            Err(IndicatorError::InvalidPeriod(_))
        ));
        // 27 bars is one short of the 2 * 14 required for period 14.
        assert!(matches!(
            adx(&data, &data, &data, 14),
            Err(IndicatorError::InsufficientData { need: 28, got: 27 })
        ));
    }
}