Skip to main content

indicators/trend/
atr.rs

1//! Average True Range (ATR).
2//!
3//! Python source: `indicators/trend/volatility/atr.py :: class ATR`
4//!
//! # Python algorithm (reference for this port)
6//! ```python
7//! high_low        = data["high"] - data["low"]
8//! high_close_prev = abs(data["high"] - data["close"].shift(1))
9//! low_close_prev  = abs(data["low"]  - data["close"].shift(1))
10//! tr  = pd.concat([high_low, high_close_prev, low_close_prev], axis=1).max(axis=1)
11//! atr = tr.rolling(period).mean()           # method=="sma"
12//! # or:
13//! atr = tr.ewm(span=period, adjust=False).mean()  # method=="ema"
14//!
15//! normalized_atr = atr / data["close"] * 100   # percentage
16//! ```
17//!
18//! Output columns: `"ATR_{period}"`, `"ATR_{period}_normalized"`.
19//!
20//! See also: `crate::functions::atr()` and `crate::functions::true_range()`.
21
22use std::collections::HashMap;
23
24use crate::functions::{self};
25use crate::error::IndicatorError;
26use crate::indicator::{Indicator, IndicatorOutput};
27use crate::registry::{param_usize, param_str};
28use crate::types::Candle;
29
30// ── Params ────────────────────────────────────────────────────────────────────
31
/// Smoothing applied to the true-range series before it is reported as ATR.
///
/// Corresponds to the Python `method` argument: `"sma"` selects
/// [`AtrMethod::Sma`] and `"ema"` selects [`AtrMethod::Ema`].
///
/// Fieldless, so it derives `Copy` and `Hash` and can be passed by value or
/// used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AtrMethod {
    /// Simple moving average over `period` bars.
    /// Python counterpart: `tr.rolling(period).mean()`.
    Sma,
    /// Exponential moving average over `period` bars.
    /// Python counterpart: `tr.ewm(span=period, adjust=False).mean()`.
    Ema,
}
37
38#[derive(Debug, Clone)]
39pub struct AtrParams {
40    /// Period.  Python default: 14.
41    pub period: usize,
42    /// Smoothing method.  Python default: `"sma"`.
43    pub method: AtrMethod,
44}
45
46impl Default for AtrParams {
47    fn default() -> Self {
48        Self {
49            period: 14,
50            method: AtrMethod::Sma,
51        }
52    }
53}
54
55// ── Indicator struct ──────────────────────────────────────────────────────────
56
57#[derive(Debug, Clone)]
58pub struct Atr {
59    pub params: AtrParams,
60}
61
62impl Atr {
63    pub fn new(params: AtrParams) -> Self {
64        Self { params }
65    }
66    pub fn with_period(period: usize) -> Self {
67        Self::new(AtrParams {
68            period,
69            ..Default::default()
70        })
71    }
72
73    fn output_key(&self) -> String {
74        format!("ATR_{}", self.params.period)
75    }
76    fn norm_key(&self) -> String {
77        format!("ATR_{}_normalized", self.params.period)
78    }
79}
80
81impl Indicator for Atr {
82    fn name(&self) -> &str {
83        "ATR"
84    }
85    fn required_len(&self) -> usize {
86        self.params.period + 1
87    } // need prev close
88    fn required_columns(&self) -> &[&'static str] {
89        &["high", "low", "close"]
90    }
91
92    /// TODO: port Python SMA/EMA-smoothed ATR + normalized output.
93    ///
94    /// `crate::functions::atr()` already implements EMA-smoothed ATR.
95    /// For SMA-smoothed, roll the true range manually.
96    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
97        self.check_len(candles)?;
98
99        let high: Vec<f64> = candles.iter().map(|c| c.high).collect();
100        let low: Vec<f64> = candles.iter().map(|c| c.low).collect();
101        let close: Vec<f64> = candles.iter().map(|c| c.close).collect();
102
103        let tr = functions::true_range(&high, &low, &close)?;
104
105        let atr_vals = match self.params.method {
106            AtrMethod::Ema => functions::ema(&tr, self.params.period)?,
107            AtrMethod::Sma => functions::sma(&tr, self.params.period)?,
108        };
109
110        let norm: Vec<f64> = atr_vals.iter().zip(&close)
111            .map(|(&a, &c)| if c == 0.0 { f64::NAN } else { a / c * 100.0 })
112            .collect();
113
114        Ok(IndicatorOutput::from_pairs([
115            (self.output_key(), atr_vals),
116            (self.norm_key(), norm),
117        ]))
118    }
119}
120
121// ── Registry factory ──────────────────────────────────────────────────────────
122
123pub fn factory(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
124    let period = param_usize(params, "period", 14)?;
125    let method = match param_str(params, "method", "sma") {
126        "ema" => AtrMethod::Ema,
127        _ => AtrMethod::Sma,
128    };
129    Ok(Box::new(Atr::new(AtrParams { period, method })))
130}
131
132// ── Tests ─────────────────────────────────────────────────────────────────────
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles from `(high, low, close)` triples. `open` copies `close`,
    /// `time` is the bar index and `volume` is a constant 1.0.
    fn candles(data: &[(f64, f64, f64)]) -> Vec<Candle> {
        data.iter()
            .enumerate()
            .map(|(i, &(h, l, c))| Candle {
                time: i as i64,
                open: c,
                high: h,
                low: l,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    /// Twenty bars with close `i`, high `i + 1` and low `i - 1`, for `i` in `1..=20`.
    ///
    /// From the second bar on, the true-range components are
    /// `high - low = 2`, `|high - prev_close| = 2` and `|low - prev_close| = 0`,
    /// so the true range is exactly 2.
    fn ramp_bars() -> Vec<(f64, f64, f64)> {
        (1..=20)
            .map(|i| (i as f64 + 1.0, i as f64 - 1.0, i as f64))
            .collect()
    }

    #[test]
    fn atr_output_has_both_columns() {
        let out = Atr::with_period(5).calculate(&candles(&ramp_bars())).unwrap();
        assert!(out.get("ATR_5").is_some());
        assert!(out.get("ATR_5_normalized").is_some());
    }

    #[test]
    fn atr_insufficient_data() {
        assert!(Atr::with_period(14).calculate(&candles(&[(10.0, 8.0, 9.0)])).is_err());
    }

    #[test]
    fn atr_normalized_is_percentage() {
        let bars = ramp_bars();
        let out = Atr::with_period(5).calculate(&candles(&bars)).unwrap();
        let atr_vals = out.get("ATR_5").unwrap();
        let norm_vals = out.get("ATR_5_normalized").unwrap();
        for (i, &(_, _, close)) in bars.iter().enumerate() {
            if !atr_vals[i].is_nan() {
                let expected = atr_vals[i] / close * 100.0;
                assert!((norm_vals[i] - expected).abs() < 1e-9);
            }
        }
    }

    #[test]
    fn atr_sma_of_constant_true_range() {
        // The last five true ranges are all 2.0, so the 5-bar SMA ends at 2.0
        // and the normalized value at 2.0 / 20.0 * 100 = 10.0, which is what
        // the Python reference produces as well. This guards against the
        // percentage test passing vacuously on an all-NaN series.
        let out = Atr::with_period(5).calculate(&candles(&ramp_bars())).unwrap();
        let last_atr = *out.get("ATR_5").unwrap().last().unwrap();
        let last_norm = *out.get("ATR_5_normalized").unwrap().last().unwrap();
        assert!((last_atr - 2.0).abs() < 1e-9, "last ATR = {}", last_atr);
        assert!((last_norm - 10.0).abs() < 1e-9, "last normalized ATR = {}", last_norm);
    }

    #[test]
    fn atr_ema_produces_full_length_columns() {
        let bars = ramp_bars();
        let atr = Atr::new(AtrParams {
            period: 5,
            method: AtrMethod::Ema,
        });
        let out = atr.calculate(&candles(&bars)).unwrap();
        assert_eq!(out.get("ATR_5").unwrap().len(), bars.len());
        assert_eq!(out.get("ATR_5_normalized").unwrap().len(), bars.len());
    }

    #[test]
    fn factory_creates_atr() {
        let params = [("period".into(), "14".into()), ("method".into(), "ema".into())].into();
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "ATR");
    }

    #[test]
    fn factory_applies_defaults() {
        let ind = factory(&HashMap::new()).unwrap();
        assert_eq!(ind.name(), "ATR");
        // Default period 14, plus one bar for the previous close.
        assert_eq!(ind.required_len(), 15);
    }
}