use crate::error::{Error, Result};
use crate::ohlcv::Candle;
use crate::traits::Indicator;
/// Average True Range (ATR) using Wilder smoothing.
///
/// During warm-up the first `period` true ranges are averaged arithmetically to
/// seed the value. After that, each new true range is blended in as
/// `avg = (avg * (period - 1) + tr) / period`.
#[derive(Debug, Clone)]
pub struct Atr {
    /// Smoothing period in bars; `Atr::new` rejects zero.
    period: usize,
    /// `period - 1` as `f64`, precomputed for the smoothing step.
    n_minus_1: f64,
    /// `1 / period`, precomputed so the smoothing step multiplies instead of dividing.
    inv_period: f64,
    /// Close of the most recently consumed bar; `None` before the first bar.
    prev_close: Option<f64>,
    /// True ranges collected during warm-up; holds `period` values once seeded.
    seed_buf: Vec<f64>,
    /// Current smoothed ATR; only meaningful while `seeded` is `true`.
    avg: f64,
    /// Whether the warm-up window is full and `avg` holds a valid ATR.
    seeded: bool,
}
impl Atr {
    /// Creates an ATR indicator that smooths true range over `period` bars.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PeriodZero`] when `period` is `0`.
    pub fn new(period: usize) -> Result<Self> {
        if period == 0 {
            return Err(Error::PeriodZero);
        }
        Ok(Self {
            period,
            n_minus_1: (period - 1) as f64,
            inv_period: 1.0 / period as f64,
            prev_close: None,
            seed_buf: Vec::with_capacity(period),
            avg: 0.0,
            seeded: false,
        })
    }

    /// Smoothing period this indicator was constructed with.
    #[must_use]
    pub const fn period(&self) -> usize {
        self.period
    }

    /// Latest ATR value, or `None` until `period` bars have been consumed.
    #[must_use]
    pub const fn value(&self) -> Option<f64> {
        if self.seeded {
            Some(self.avg)
        } else {
            None
        }
    }

    /// Computes ATR over column-oriented high/low/close data and advances `self`.
    ///
    /// Returns one value per input bar. Bars inside the warm-up window are
    /// reported as `f64::NAN`. Afterwards, `self` is in the same state as if every
    /// bar had been passed to [`Indicator::update`] in order, so streaming can
    /// continue after a batch call.
    ///
    /// If `self` has not consumed any bar yet and at least `period` bars are
    /// supplied, a specialised loop over the raw slices is used. Otherwise each
    /// bar is wrapped in a [`Candle`] (open set to the close, remaining fields
    /// zeroed) and fed through `update`.
    ///
    /// # Panics
    ///
    /// Panics if `high`, `low` and `close` do not all have the same length.
    pub fn batch_atr(&mut self, high: &[f64], low: &[f64], close: &[f64]) -> Vec<f64> {
        let n = high.len();
        // Validate the columns once up front. Without this check, a short
        // `low`/`close` panicked on an index mid-loop and a long one was
        // silently truncated. The check also lets the optimiser hoist bounds
        // checks out of the loops below.
        assert!(
            low.len() == n && close.len() == n,
            "batch_atr: column lengths differ (high = {n}, low = {}, close = {})",
            low.len(),
            close.len(),
        );

        let p = self.period;
        let fresh = !self.seeded && self.seed_buf.is_empty() && self.prev_close.is_none();
        if !fresh || n < p {
            // Mid-stream state or too few bars to seed: defer to `update`.
            return high
                .iter()
                .zip(low)
                .zip(close)
                .map(|((&h, &l), &c)| {
                    self.update(Candle::new_unchecked(c, h, l, c, 0.0, 0))
                        .unwrap_or(f64::NAN)
                })
                .collect();
        }

        let mut out = Vec::with_capacity(n);
        out.resize(p - 1, f64::NAN);

        // Warm-up. The first bar has no previous close, so its true range is
        // `high - low`; the other `p - 1` bars use the full formula. Summation
        // is left to right in bar order, like the streaming seed.
        let mut sum_tr = high[0] - low[0];
        self.seed_buf.push(sum_tr);
        let mut prev_close = close[0];
        for i in 1..p {
            let tr = Self::true_range_with_prev(high[i], low[i], prev_close);
            prev_close = close[i];
            self.seed_buf.push(tr);
            sum_tr += tr;
        }
        // Divide rather than multiply by `inv_period`, matching how `update` seeds.
        let mut avg = sum_tr / p as f64;
        out.push(avg);

        // Wilder smoothing for the remaining bars, using the same expression as `update`.
        for i in p..n {
            let tr = Self::true_range_with_prev(high[i], low[i], prev_close);
            prev_close = close[i];
            avg = avg.mul_add(self.n_minus_1, tr) * self.inv_period;
            out.push(avg);
        }

        self.prev_close = Some(prev_close);
        self.avg = avg;
        self.seeded = true;
        out
    }

    /// True range of a bar relative to the previous close: the largest of
    /// `high - low`, `|high - prev_close|` and `|low - prev_close|`.
    fn true_range_with_prev(high: f64, low: f64, prev_close: f64) -> f64 {
        (high - low)
            .max((high - prev_close).abs())
            .max((low - prev_close).abs())
    }
}
impl Indicator for Atr {
    type Input = Candle;
    type Output = f64;

    /// Consumes one candle and returns the ATR once warm-up is complete.
    ///
    /// The first `period` true ranges are buffered and their arithmetic mean
    /// seeds the average. Every later candle applies one Wilder smoothing step.
    fn update(&mut self, candle: Candle) -> Option<f64> {
        // Store this candle's close and get the previous one back in a single step.
        let previous = self.prev_close.replace(candle.close);
        let tr = candle.true_range(previous);

        if !self.seeded {
            self.seed_buf.push(tr);
            if self.seed_buf.len() != self.period {
                return None;
            }
            let total: f64 = self.seed_buf.iter().sum();
            self.avg = total / self.period as f64;
            self.seeded = true;
            return Some(self.avg);
        }

        self.avg = self.avg.mul_add(self.n_minus_1, tr) * self.inv_period;
        Some(self.avg)
    }

    /// Returns the indicator to its freshly-constructed warm-up state.
    fn reset(&mut self) {
        self.seeded = false;
        self.avg = 0.0;
        self.seed_buf.clear();
        self.prev_close = None;
    }

    /// Number of candles needed before the first value is produced.
    fn warmup_period(&self) -> usize {
        self.period
    }

    /// `true` once the warm-up window has been filled.
    fn is_ready(&self) -> bool {
        self.seeded
    }

    /// Short display name of this indicator.
    fn name(&self) -> &'static str {
        "ATR"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Builds a validated candle from high/low/close. The open is set to the
    /// close and the remaining constructor arguments are fixed (`1.0`, `0`).
    fn c(h: f64, l: f64, cl: f64) -> Candle {
        Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
    }

    /// Unoptimised reference ATR, used as the oracle in the property test.
    ///
    /// The first bar's true range is `high - low`. Later bars take the max of
    /// `high - low`, `|high - prev_close|` and `|low - prev_close|`. The first
    /// `period` true ranges are averaged to seed the value; after that, each bar
    /// applies `(avg * (n - 1) + tr) / n` with a plain division. Warm-up bars
    /// yield `None`.
    fn atr_naive(hlc: &[(f64, f64, f64)], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = Vec::with_capacity(hlc.len());
        let mut trs: Vec<f64> = Vec::new();
        let mut avg: Option<f64> = None;
        let mut prev_close: Option<f64> = None;
        for &(h, l, cl) in hlc {
            let tr = match prev_close {
                None => h - l,
                Some(pc) => (h - l).max((h - pc).abs()).max((l - pc).abs()),
            };
            prev_close = Some(cl);
            if let Some(a) = avg {
                // Already seeded: apply one smoothing step.
                let na = (a * (n - 1.0) + tr) / n;
                avg = Some(na);
                out.push(Some(na));
            } else {
                // Still warming up: collect true ranges until `period` of them exist.
                trs.push(tr);
                if trs.len() == period {
                    avg = Some(trs.iter().sum::<f64>() / n);
                    out.push(avg);
                } else {
                    out.push(None);
                }
            }
        }
        out
    }

    #[test]
    fn rejects_zero_period() {
        assert!(matches!(Atr::new(0), Err(Error::PeriodZero)));
    }

    /// `period` and `name` report construction data. `value` stays `None`
    /// until the `period`-th candle has been consumed.
    #[test]
    fn accessors_and_metadata() {
        let mut atr = Atr::new(14).unwrap();
        assert_eq!(atr.period(), 14);
        assert_eq!(atr.name(), "ATR");
        assert_eq!(atr.value(), None);
        for _ in 0..14 {
            atr.update(c(11.0, 9.0, 10.0));
        }
        assert!(atr.value().is_some());
    }

    /// With period 3, bars 0 and 1 are warm-up and bar 2 is the first output.
    #[test]
    fn warmup_emits_on_period_th_candle() {
        let candles = vec![
            c(2.0, 1.0, 1.5),
            c(3.0, 2.0, 2.5),
            c(4.0, 3.0, 3.5),
            c(5.0, 4.0, 4.5),
            c(6.0, 5.0, 5.5),
        ];
        let mut atr = Atr::new(3).unwrap();
        let out = atr.batch(&candles);
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        assert!(out[2].is_some());
        assert!(out[3].is_some());
    }

    /// Every bar is high 11 / low 9 / close 10, so every true range is 2.
    /// The ATR therefore stays at 2 from the first output (index 13) onward.
    #[test]
    fn constant_range_yields_constant_atr() {
        let candles: Vec<Candle> = (0..30).map(|_| c(11.0, 9.0, 10.0)).collect();
        let mut atr = Atr::new(14).unwrap();
        let out = atr.batch(&candles);
        for v in out.iter().skip(13).flatten() {
            assert_relative_eq!(*v, 2.0, epsilon = 1e-12);
        }
    }

    /// TR0 = 6 - 4 = 2 (there is no previous close).
    /// TR1 = max(10 - 9, |10 - 5|, |9 - 5|) = 5, so the gap above the previous
    /// close dominates. Seed = (2 + 5) / 2 = 3.5.
    #[test]
    fn gap_up_uses_high_minus_prev_close() {
        let candles = vec![
            c(6.0, 4.0, 5.0),
            c(10.0, 9.0, 9.5),
        ];
        let mut atr = Atr::new(2).unwrap();
        let out = atr.batch(&candles);
        assert_relative_eq!(out[1].unwrap(), 3.5, epsilon = 1e-12);
    }

    /// `BatchExt::batch` must give exactly the same results as per-candle `update` calls.
    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..40)
            .map(|i| {
                let mid = f64::from(i) + 10.0;
                c(mid + 0.5, mid - 0.5, mid)
            })
            .collect();
        let mut a = Atr::new(14).unwrap();
        let mut b = Atr::new(14).unwrap();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }

    /// After `reset` the indicator is back in warm-up, so the next candle yields `None`.
    #[test]
    fn reset_clears_state() {
        let candles: Vec<Candle> = (0..20).map(|_| c(11.0, 9.0, 10.0)).collect();
        let mut atr = Atr::new(5).unwrap();
        atr.batch(&candles);
        assert!(atr.is_ready());
        atr.reset();
        assert!(!atr.is_ready());
        assert_eq!(atr.update(candles[0]), None);
    }

    /// True ranges are maxima of a range and absolute values, so they are
    /// non-negative. Their averages must be non-negative too.
    #[test]
    fn never_negative() {
        let candles: Vec<Candle> = (0..200)
            .map(|i| {
                let base = 100.0 + (f64::from(i) * 0.3).sin() * 5.0;
                c(base + 1.0, base - 1.0, base)
            })
            .collect();
        let mut atr = Atr::new(14).unwrap();
        for v in atr.batch(&candles).into_iter().flatten() {
            assert!(v >= 0.0, "ATR must be non-negative: {v}");
        }
    }

    /// Element-wise equality that also treats two NaNs (warm-up slots) as equal.
    /// NOTE(review): despite the name this compares with `==`, not by bit
    /// pattern, so `0.0` and `-0.0` are considered equal.
    fn bits_eq(a: &[f64], b: &[f64]) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b)
                .all(|(x, y)| x == y || (x.is_nan() && y.is_nan()))
    }

    /// Reference for `batch_atr`: feeds each bar through `update` on a fresh
    /// indicator and maps warm-up `None` to NaN.
    fn atr_replay(period: usize, high: &[f64], low: &[f64], close: &[f64]) -> Vec<f64> {
        let mut a = Atr::new(period).unwrap();
        (0..high.len())
            .map(|i| {
                let candle = Candle::new_unchecked(close[i], high[i], low[i], close[i], 0.0, 0);
                a.update(candle).unwrap_or(f64::NAN)
            })
            .collect()
    }

    /// Deterministic test columns: close = 100 + 5 * sin(0.3 * i), and
    /// high/low = close +/- 1.
    fn columns(n: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let base: Vec<f64> = (0..n)
            .map(|i| (f64::from(u32::try_from(i).unwrap()) * 0.3).sin() * 5.0 + 100.0)
            .collect();
        let high = base.iter().map(|b| b + 1.0).collect();
        let low = base.iter().map(|b| b - 1.0).collect();
        (high, low, base)
    }

    /// The fresh-state fast path of `batch_atr` must match the streaming path
    /// exactly. It must also leave `self` in a state from which the next
    /// `update` agrees with an indicator that streamed every bar.
    #[test]
    fn batch_atr_fast_path_is_bit_identical() {
        let (high, low, close) = columns(300);
        let mut atr = Atr::new(14).unwrap();
        let got = atr.batch_atr(&high, &low, &close);
        assert!(bits_eq(&got, &atr_replay(14, &high, &low, &close)));
        let mut ref_atr = Atr::new(14).unwrap();
        for i in 0..high.len() {
            ref_atr.update(Candle::new_unchecked(
                close[i], high[i], low[i], close[i], 0.0, 0,
            ));
        }
        let next = Candle::new_unchecked(101.0, 102.0, 100.0, 101.0, 0.0, 0);
        assert_eq!(atr.update(next), ref_atr.update(next));
    }

    /// `atr` has already consumed bar 0, so `batch_atr` must take the
    /// `update` fallback. The expected series replays bar 0 once up front and
    /// then all bars again, mirroring what `atr` sees.
    #[test]
    fn batch_atr_falls_back_when_not_fresh() {
        let (high, low, close) = columns(40);
        let mut atr = Atr::new(14).unwrap();
        atr.update(Candle::new_unchecked(
            close[0], high[0], low[0], close[0], 0.0, 0,
        ));
        let mut ref_atr = Atr::new(14).unwrap();
        ref_atr.update(Candle::new_unchecked(
            close[0], high[0], low[0], close[0], 0.0, 0,
        ));
        let want: Vec<f64> = (0..high.len())
            .map(|i| {
                ref_atr
                    .update(Candle::new_unchecked(
                        close[i], high[i], low[i], close[i], 0.0, 0,
                    ))
                    .unwrap_or(f64::NAN)
            })
            .collect();
        assert!(bits_eq(&atr.batch_atr(&high, &low, &close), &want));
    }

    /// Five bars are fewer than the period of 14: the fallback path runs and
    /// every output is a warm-up NaN.
    #[test]
    fn batch_atr_sub_period_slice_falls_back() {
        let (high, low, close) = columns(5);
        let mut atr = Atr::new(14).unwrap();
        let got = atr.batch_atr(&high, &low, &close);
        assert!(bits_eq(&got, &atr_replay(14, &high, &low, &close)));
        assert!(got.iter().all(|x| x.is_nan()));
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        // Random bars are generated as (low, range, frac), giving high = low + range
        // and a close inside [low, high]. `batch` must match `atr_naive` to a
        // 1e-9 relative tolerance and agree on which bars are warm-up.
        #[test]
        fn atr_matches_naive(
            period in 1usize..15,
            bars in proptest::collection::vec(
                (10.0_f64..1000.0, 0.0_f64..50.0, 0.0_f64..1.0),
                0..120,
            ),
        ) {
            let hlc: Vec<(f64, f64, f64)> = bars
                .iter()
                .map(|&(low, range, frac)| (low + range, low, low + range * frac))
                .collect();
            let candles: Vec<Candle> = hlc.iter().map(|&(h, l, cl)| c(h, l, cl)).collect();
            let mut atr = Atr::new(period).unwrap();
            let got = atr.batch(&candles);
            let want = atr_naive(&hlc, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}