Skip to main content

quantwave_core/indicators/incremental/
ta_atr.rs

//! Native streaming ATR with O(1) work per bar — TA-Lib-compatible Wilder smoothing of the true range.
2
3use crate::traits::Next;
4
5/// Average True Range — matches `talib_rs::volatility::atr`.
6#[derive(Debug, Clone)]
7#[allow(non_camel_case_types)]
8pub struct TaATR {
9    pub timeperiod: usize,
10    period_f: f64,
11    prev_close: Option<f64>,
12    bars_seen: usize,
13    warmup_tr_count: usize,
14    warmup_sum: f64,
15    atr: f64,
16}
17
18impl TaATR {
19    pub fn new(timeperiod: usize) -> Self {
20        Self {
21            timeperiod,
22            period_f: timeperiod as f64,
23            prev_close: None,
24            bars_seen: 0,
25            warmup_tr_count: 0,
26            warmup_sum: 0.0,
27            atr: 0.0,
28        }
29    }
30
31    #[inline]
32    fn true_range(&self, high: f64, low: f64, prev_close: f64) -> f64 {
33        let hl = high - low;
34        let hc = (high - prev_close).abs();
35        let lc = (low - prev_close).abs();
36        hl.max(hc).max(lc)
37    }
38}
39
40impl Next<(f64, f64, f64)> for TaATR {
41    type Output = f64;
42
43    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
44        let period = self.timeperiod;
45        if period < 1 {
46            return f64::NAN;
47        }
48
49        if self.bars_seen == 0 {
50            self.prev_close = Some(close);
51            self.bars_seen = 1;
52            return f64::NAN;
53        }
54
55        let pc = self.prev_close.unwrap();
56        let tr = self.true_range(high, low, pc);
57        self.prev_close = Some(close);
58        self.bars_seen += 1;
59
60        if self.warmup_tr_count < period {
61            self.warmup_tr_count += 1;
62            self.warmup_sum += tr;
63            if self.warmup_tr_count < period {
64                return f64::NAN;
65            }
66            self.atr = self.warmup_sum / self.period_f;
67            return self.atr;
68        }
69
70        self.atr = (self.atr * (self.period_f - 1.0) + tr) / self.period_f;
71        self.atr
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Streaming output must match the batch `talib_rs` ATR bar-for-bar, including the
        // NaN warm-up prefix, across a range of periods (not just the default 14).
        #[test]
        fn test_ta_atr_parity(
            bars in prop::collection::vec(
                (1.0f64..100.0, 1.0f64..100.0, 1.0f64..100.0),
                1..100
            ),
            period in 2usize..=30
        ) {
            // Build well-formed bars: high/low bracket both raw draws and the close.
            let len = bars.len();
            let mut high = Vec::with_capacity(len);
            let mut low = Vec::with_capacity(len);
            let mut close = Vec::with_capacity(len);
            for &(a, b, c) in &bars {
                high.push(a.max(b).max(c));
                low.push(a.min(b).min(c));
                close.push(c);
            }

            let mut ta_atr = TaATR::new(period);
            let streaming: Vec<f64> = high
                .iter()
                .zip(&low)
                .zip(&close)
                .map(|((&h, &l), &c)| ta_atr.next((h, l, c)))
                .collect();
            // A batch error is treated as "no output yet" so short inputs compare as all-NaN.
            let batch = talib_rs::volatility::atr(&high, &low, &close, period)
                .unwrap_or_else(|_| vec![f64::NAN; len]);

            // The `zip` below would silently hide a length mismatch, so check it explicitly.
            prop_assert_eq!(streaming.len(), batch.len());
            for (i, (s, b)) in streaming.iter().zip(&batch).enumerate() {
                prop_assert_eq!(
                    s.is_nan(),
                    b.is_nan(),
                    "NaN mismatch at bar {}: streaming={}, batch={}",
                    i,
                    s,
                    b
                );
                if !s.is_nan() {
                    prop_assert!(
                        approx::relative_eq!(*s, *b, epsilon = 1e-6),
                        "value mismatch at bar {}: streaming={}, batch={}",
                        i,
                        s,
                        b
                    );
                }
            }
        }
    }
}