Skip to main content

indicators/trend/
atr.rs

1//! Average True Range (ATR).
2//!
3//! Python source: `indicators/trend/volatility/atr.py :: class ATR`
4//!
5//! # Python algorithm (to port)
6//! ```python
7//! high_low        = data["high"] - data["low"]
8//! high_close_prev = abs(data["high"] - data["close"].shift(1))
9//! low_close_prev  = abs(data["low"]  - data["close"].shift(1))
10//! tr  = pd.concat([high_low, high_close_prev, low_close_prev], axis=1).max(axis=1)
11//! atr = tr.rolling(period).mean()           # method=="sma"
12//! # or:
13//! atr = tr.ewm(span=period, adjust=False).mean()  # method=="ema"
14//!
15//! normalized_atr = atr / data["close"] * 100   # percentage
16//! ```
17//!
18//! Output columns: `"ATR_{period}"`, `"ATR_{period}_normalized"`.
19//!
20//! See also: `crate::functions::atr()` and `crate::functions::true_range()`.
21
22use std::collections::HashMap;
23
24use crate::error::IndicatorError;
25use crate::functions::{self};
26use crate::indicator::{Indicator, IndicatorOutput};
27use crate::registry::{param_str, param_usize};
28use crate::types::Candle;
29
30// ── Params ────────────────────────────────────────────────────────────────────
31
/// Smoothing applied to the true-range series.
///
/// Fieldless, so it is `Copy`: pass it by value rather than cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AtrMethod {
    /// Simple moving average: `tr.rolling(period).mean()` (Python `"sma"`).
    Sma,
    /// Exponential moving average: `tr.ewm(span=period, adjust=False).mean()`
    /// (Python `"ema"`).
    Ema,
}
37
38#[derive(Debug, Clone)]
39pub struct AtrParams {
40    /// Period.  Python default: 14.
41    pub period: usize,
42    /// Smoothing method.  Python default: `"sma"`.
43    pub method: AtrMethod,
44}
45
46impl Default for AtrParams {
47    fn default() -> Self {
48        Self {
49            period: 14,
50            method: AtrMethod::Sma,
51        }
52    }
53}
54
55// ── Indicator struct ──────────────────────────────────────────────────────────
56
57/// Average True Range indicator.
58///
59/// Note: smoothing is SMA (default) or EMA for Python-parity (`AtrParams::method`),
60/// **not** Wilder's RMA. A Wilder-smoothed ATR will produce different values.
61#[derive(Debug, Clone)]
62pub struct Atr {
63    pub params: AtrParams,
64}
65
66impl Atr {
67    pub fn new(params: AtrParams) -> Self {
68        Self { params }
69    }
70    pub fn with_period(period: usize) -> Self {
71        Self::new(AtrParams {
72            period,
73            ..Default::default()
74        })
75    }
76
77    fn output_key(&self) -> String {
78        format!("ATR_{}", self.params.period)
79    }
80    fn norm_key(&self) -> String {
81        format!("ATR_{}_normalized", self.params.period)
82    }
83}
84
85impl Indicator for Atr {
86    fn name(&self) -> &'static str {
87        "ATR"
88    }
89    fn required_len(&self) -> usize {
90        self.params.period + 1
91    } // need prev close
92    fn required_columns(&self) -> &[&'static str] {
93        &["high", "low", "close"]
94    }
95
96    /// Ports the Python ATR calculation.
97    ///
98    /// True range = `max(H−L, |H−prev_C|, |L−prev_C|)`.  For the first bar
99    /// there is no previous close, so `functions::true_range` is expected to
100    /// use `H−L` alone (matching pandas `skipna=True` max behaviour).
101    ///
102    /// SMA path: `tr.rolling(period).mean()` — `NaN` for first `period` bars.
103    /// EMA path: `tr.ewm(span=period, adjust=False).mean()` — value from bar 0.
104    ///
105    /// Normalised ATR = `atr / close * 100` (percentage of price).
106    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
107        self.check_len(candles)?;
108
109        let high: Vec<f64> = candles.iter().map(|c| c.high).collect();
110        let low: Vec<f64> = candles.iter().map(|c| c.low).collect();
111        let close: Vec<f64> = candles.iter().map(|c| c.close).collect();
112
113        let tr = functions::true_range(&high, &low, &close)?;
114
115        let atr_vals = match self.params.method {
116            // ema_nan_aware seeds from the first TR value (adjust=False),
117            // matching Python's tr.ewm(span=period, adjust=False).mean().
118            AtrMethod::Ema => functions::ema_nan_aware(&tr, self.params.period)?,
119            AtrMethod::Sma => functions::sma(&tr, self.params.period)?,
120        };
121
122        let norm: Vec<f64> = atr_vals
123            .iter()
124            .zip(&close)
125            .map(|(&a, &c)| if c == 0.0 { f64::NAN } else { a / c * 100.0 })
126            .collect();
127
128        Ok(IndicatorOutput::from_pairs([
129            (self.output_key(), atr_vals),
130            (self.norm_key(), norm),
131        ]))
132    }
133}
134
135// ── Registry factory ──────────────────────────────────────────────────────────
136
137pub fn factory<S: ::std::hash::BuildHasher>(
138    params: &HashMap<String, String, S>,
139) -> Result<Box<dyn Indicator>, IndicatorError> {
140    let period = param_usize(params, "period", 14)?;
141    let method = match param_str(params, "method", "sma") {
142        "ema" => AtrMethod::Ema,
143        _ => AtrMethod::Sma,
144    };
145    Ok(Box::new(Atr::new(AtrParams { period, method })))
146}
147
148// ── Tests ─────────────────────────────────────────────────────────────────────
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles from `(high, low, close)` triples; `open` mirrors `close`.
    fn candles(data: &[(f64, f64, f64)]) -> Vec<Candle> {
        data.iter()
            .enumerate()
            .map(|(i, &(h, l, c))| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: h,
                low: l,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    /// `n` bars with high = i + 1, low = i − 1, close = i for i in 1..=n.
    ///
    /// Every bar has a true range of exactly 2: H−L = 2, |H−prev_C| = 2 and
    /// |L−prev_C| = 0 (bar 0 has no previous close, but H−L is still 2).
    fn ramp(n: u32) -> Vec<(f64, f64, f64)> {
        (1..=n)
            .map(|i| {
                let x = f64::from(i);
                (x + 1.0, x - 1.0, x)
            })
            .collect()
    }

    #[test]
    fn atr_output_has_both_columns() {
        let bars = ramp(20);
        let atr = Atr::with_period(5);
        let out = atr.calculate(&candles(&bars)).unwrap();
        assert!(out.get("ATR_5").is_some());
        assert!(out.get("ATR_5_normalized").is_some());
    }

    #[test]
    fn atr_insufficient_data() {
        assert!(
            Atr::with_period(14)
                .calculate(&candles(&[(10.0, 8.0, 9.0)]))
                .is_err()
        );
    }

    #[test]
    fn atr_normalized_is_percentage() {
        let bars = ramp(20);
        let atr = Atr::with_period(5);
        let out = atr.calculate(&candles(&bars)).unwrap();
        let atr_vals = out.get("ATR_5").unwrap();
        let norm_vals = out.get("ATR_5_normalized").unwrap();
        let close: Vec<f64> = bars.iter().map(|&(_, _, c)| c).collect();
        for i in 0..bars.len() {
            if !atr_vals[i].is_nan() {
                let expected = atr_vals[i] / close[i] * 100.0;
                assert!((norm_vals[i] - expected).abs() < 1e-9);
            }
        }
    }

    #[test]
    fn atr_constant_true_range_gives_constant_atr() {
        // TR ≡ 2 on every bar, so any smoothing of it must also be 2.
        let bars = ramp(20);
        for method in [AtrMethod::Sma, AtrMethod::Ema] {
            let atr = Atr::new(AtrParams { period: 5, method });
            let out = atr.calculate(&candles(&bars)).unwrap();
            let atr_vals = out.get("ATR_5").unwrap();
            let mut seen = 0;
            for i in 0..bars.len() {
                let v = atr_vals[i];
                if !v.is_nan() {
                    assert!((v - 2.0).abs() < 1e-9, "bar {i}: {v}");
                    seen += 1;
                }
            }
            // Once the 5-bar window is full there must be real values.
            assert!(seen >= bars.len() - 5);
        }
    }

    #[test]
    fn atr_normalized_is_nan_for_zero_close() {
        // Close pinned at 0.0 with a 1.0 range: TR ≡ 1, ATR is finite, but it
        // cannot be expressed as a percentage of a zero price.
        let bars = [(1.0, 0.0, 0.0); 10];
        let out = Atr::with_period(3).calculate(&candles(&bars)).unwrap();
        let atr_vals = out.get("ATR_3").unwrap();
        let norm_vals = out.get("ATR_3_normalized").unwrap();
        let last = bars.len() - 1;
        assert!((atr_vals[last] - 1.0).abs() < 1e-9);
        assert!(norm_vals[last].is_nan());
    }

    #[test]
    fn factory_creates_atr() {
        let params = [
            ("period".into(), "14".into()),
            ("method".into(), "ema".into()),
        ]
        .into();
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "ATR");
    }

    #[test]
    fn factory_uses_defaults_when_params_missing() {
        let params: HashMap<String, String> = HashMap::new();
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "ATR");
        // Default period 14 plus one bar for the previous close.
        assert_eq!(ind.required_len(), 15);
    }
}