1use std::collections::HashMap;
23
24use crate::error::IndicatorError;
25use crate::functions::{self};
26use crate::indicator::{Indicator, IndicatorOutput};
27use crate::registry::{param_str, param_usize};
28use crate::types::Candle;
29
/// Smoothing method used to average the true-range series into ATR.
///
/// This is a small, fieldless enum, so it is `Copy` and can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AtrMethod {
    /// Simple moving average of the true range (`functions::sma`).
    Sma,
    /// Exponential moving average of the true range (`functions::ema_nan_aware`).
    Ema,
}
37
38#[derive(Debug, Clone)]
39pub struct AtrParams {
40 pub period: usize,
42 pub method: AtrMethod,
44}
45
46impl Default for AtrParams {
47 fn default() -> Self {
48 Self {
49 period: 14,
50 method: AtrMethod::Sma,
51 }
52 }
53}
54
55#[derive(Debug, Clone)]
62pub struct Atr {
63 pub params: AtrParams,
64}
65
66impl Atr {
67 pub fn new(params: AtrParams) -> Self {
68 Self { params }
69 }
70 pub fn with_period(period: usize) -> Self {
71 Self::new(AtrParams {
72 period,
73 ..Default::default()
74 })
75 }
76
77 fn output_key(&self) -> String {
78 format!("ATR_{}", self.params.period)
79 }
80 fn norm_key(&self) -> String {
81 format!("ATR_{}_normalized", self.params.period)
82 }
83}
84
85impl Indicator for Atr {
86 fn name(&self) -> &'static str {
87 "ATR"
88 }
89 fn required_len(&self) -> usize {
90 self.params.period + 1
91 } fn required_columns(&self) -> &[&'static str] {
93 &["high", "low", "close"]
94 }
95
96 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
107 self.check_len(candles)?;
108
109 let high: Vec<f64> = candles.iter().map(|c| c.high).collect();
110 let low: Vec<f64> = candles.iter().map(|c| c.low).collect();
111 let close: Vec<f64> = candles.iter().map(|c| c.close).collect();
112
113 let tr = functions::true_range(&high, &low, &close)?;
114
115 let atr_vals = match self.params.method {
116 AtrMethod::Ema => functions::ema_nan_aware(&tr, self.params.period)?,
119 AtrMethod::Sma => functions::sma(&tr, self.params.period)?,
120 };
121
122 let norm: Vec<f64> = atr_vals
123 .iter()
124 .zip(&close)
125 .map(|(&a, &c)| if c == 0.0 { f64::NAN } else { a / c * 100.0 })
126 .collect();
127
128 Ok(IndicatorOutput::from_pairs([
129 (self.output_key(), atr_vals),
130 (self.norm_key(), norm),
131 ]))
132 }
133}
134
135pub fn factory<S: ::std::hash::BuildHasher>(
138 params: &HashMap<String, String, S>,
139) -> Result<Box<dyn Indicator>, IndicatorError> {
140 let period = param_usize(params, "period", 14)?;
141 let method = match param_str(params, "method", "sma") {
142 "ema" => AtrMethod::Ema,
143 _ => AtrMethod::Sma,
144 };
145 Ok(Box::new(Atr::new(AtrParams { period, method })))
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles from `(high, low, close)` triples.
    ///
    /// `open` mirrors `close`, `time` is the bar index, and `volume` is 1.0.
    fn candles(data: &[(f64, f64, f64)]) -> Vec<Candle> {
        data.iter()
            .enumerate()
            .map(|(i, &(h, l, c))| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: h,
                low: l,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    /// Twenty steadily rising bars with a constant 2.0 high-low spread.
    fn rising_bars() -> Vec<(f64, f64, f64)> {
        (1..=20_i32)
            .map(|i| {
                let x = f64::from(i);
                (x + 1.0, x - 1.0, x)
            })
            .collect()
    }

    #[test]
    fn atr_output_has_both_columns() {
        let out = Atr::with_period(5)
            .calculate(&candles(&rising_bars()))
            .unwrap();
        assert!(out.get("ATR_5").is_some());
        assert!(out.get("ATR_5_normalized").is_some());
    }

    #[test]
    fn atr_ema_method_produces_columns() {
        let atr = Atr::new(AtrParams {
            period: 5,
            method: AtrMethod::Ema,
        });
        let out = atr.calculate(&candles(&rising_bars())).unwrap();
        assert!(out.get("ATR_5").is_some());
        assert!(out.get("ATR_5_normalized").is_some());
    }

    #[test]
    fn atr_insufficient_data() {
        assert!(
            Atr::with_period(14)
                .calculate(&candles(&[(10.0, 8.0, 9.0)]))
                .is_err()
        );
    }

    #[test]
    fn atr_normalized_is_percentage() {
        let bars = rising_bars();
        let out = Atr::with_period(5).calculate(&candles(&bars)).unwrap();
        let atr_vals = out.get("ATR_5").unwrap();
        let norm_vals = out.get("ATR_5_normalized").unwrap();
        for (i, &(_, _, close)) in bars.iter().enumerate() {
            if !atr_vals[i].is_nan() {
                let expected = atr_vals[i] / close * 100.0;
                assert!((norm_vals[i] - expected).abs() < 1e-9, "bar {i}");
            }
        }
    }

    #[test]
    fn atr_normalized_is_nan_when_close_is_zero() {
        let bars = vec![(1.0, -1.0, 0.0); 10];
        let out = Atr::with_period(3).calculate(&candles(&bars)).unwrap();
        let norm_vals = out.get("ATR_3_normalized").unwrap();
        assert!(norm_vals.iter().all(|v| v.is_nan()));
    }

    #[test]
    fn factory_creates_atr() {
        let params: HashMap<String, String> = [
            ("period".to_string(), "14".to_string()),
            ("method".to_string(), "ema".to_string()),
        ]
        .into();
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "ATR");
    }
}