// mantis_ta/indicators/volatility/atr.rs
use crate::indicators::Indicator;
use crate::types::Candle;

/// Streaming Average True Range (ATR) indicator.
///
/// Candles are fed one at a time. The first `period` true ranges are
/// averaged to seed the indicator; every later candle updates the value with
/// the recurrence `(prev_atr * (period - 1) + true_range) / period`.
#[derive(Debug, Clone)]
pub struct ATR {
    /// Length of the seed average and of the smoothing recurrence (always > 0).
    period: usize,
    /// Close of the most recently processed candle; `None` before the first one.
    prev_close: Option<f64>,
    /// Number of true ranges accumulated while the indicator is still warming up.
    count: usize,
    /// Running sum of the true ranges collected during warm-up.
    sum_tr: f64,
    /// Most recent ATR value; `None` until `period` candles have been processed.
    atr: Option<f64>,
}
41
42impl ATR {
43 pub fn new(period: usize) -> Self {
44 assert!(period > 0, "period must be > 0");
45 Self {
46 period,
47 prev_close: None,
48 count: 0,
49 sum_tr: 0.0,
50 atr: None,
51 }
52 }
53
54 #[inline]
55 fn true_range(&self, candle: &Candle) -> f64 {
56 let hl = candle.high - candle.low;
57 match self.prev_close {
58 None => hl,
59 Some(prev) => {
60 let h_pc = (candle.high - prev).abs();
61 let l_pc = (candle.low - prev).abs();
62 hl.max(h_pc).max(l_pc)
63 }
64 }
65 }
66}
67
68impl Indicator for ATR {
69 type Output = f64;
70
71 fn next(&mut self, candle: &Candle) -> Option<Self::Output> {
72 let tr = self.true_range(candle);
73
74 let output = if let Some(prev_atr) = self.atr {
75 let next_atr = (prev_atr * (self.period as f64 - 1.0) + tr) / self.period as f64;
76 self.atr = Some(next_atr);
77 self.atr
78 } else {
79 self.sum_tr += tr;
80 self.count += 1;
81 if self.count >= self.period {
82 let initial = self.sum_tr / self.period as f64;
83 self.atr = Some(initial);
84 self.atr
85 } else {
86 None
87 }
88 };
89
90 self.prev_close = Some(candle.close);
91 output
92 }
93
94 fn reset(&mut self) {
95 self.prev_close = None;
96 self.count = 0;
97 self.sum_tr = 0.0;
98 self.atr = None;
99 }
100
101 fn warmup_period(&self) -> usize {
102 self.period
103 }
104
105 fn clone_boxed(&self) -> Box<dyn Indicator<Output = Self::Output>> {
106 Box::new(self.clone())
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candle from `(high, low, close)`; `open` mirrors `close` and
    /// the remaining fields are zeroed because ATR does not read them.
    fn candle(high: f64, low: f64, close: f64) -> Candle {
        Candle {
            timestamp: 0,
            open: close,
            high,
            low,
            close,
            volume: 0.0,
        }
    }

    fn sample_candles() -> Vec<Candle> {
        [
            (1.0, 0.5, 0.8),
            (2.0, 0.5, 1.5),
            (3.0, 1.0, 2.5),
            (3.5, 1.5, 3.0),
        ]
        .iter()
        .map(|&(h, l, c)| candle(h, l, c))
        .collect()
    }

    #[test]
    fn atr_emits_after_warmup() {
        let mut atr = ATR::new(3);
        let candles = sample_candles();

        let outputs: Vec<_> = candles.iter().map(|c| atr.next(c)).collect();
        let wp = atr.warmup_period();

        // Nothing is emitted until the warm-up window is full.
        assert!(outputs.iter().take(wp - 1).all(|o| o.is_none()));
        assert!(outputs[wp - 1].is_some());

        // True ranges are 0.5, 1.5, 2.0, 2.0: the seed is their first-three
        // mean (4/3), and the next value is (4/3 * 2 + 2) / 3 = 14/9.
        let expected = [4.0 / 3.0, 14.0 / 9.0];
        for (out, want) in outputs[wp - 1..].iter().zip(expected.iter()) {
            let got = out.expect("ATR must emit once warm-up is complete");
            assert!(
                (got - want).abs() < 1e-12,
                "expected {}, got {}",
                want,
                got
            );
        }
    }

    #[test]
    fn atr_reset_restarts_warmup() {
        let mut atr = ATR::new(3);
        let candles = sample_candles();

        let first: Vec<_> = candles.iter().map(|c| atr.next(c)).collect();
        atr.reset();
        let second: Vec<_> = candles.iter().map(|c| atr.next(c)).collect();

        assert_eq!(first, second);
    }
}