Skip to main content

mantis_ta/indicators/volatility/
atr.rs

1use crate::indicators::Indicator;
2use crate::types::Candle;
3
/// Average True Range (ATR) using Wilder's smoothing.
///
/// The true range (TR) of a candle is the largest of `high - low`,
/// `|high - previous close|` and `|low - previous close|`. The very first
/// candle has no previous close, so its TR is just `high - low`.
///
/// Nothing is emitted until `period` candles have been seen. On the
/// `period`-th candle the indicator emits the simple mean of the first
/// `period` true ranges; after that each value follows Wilder's recursion
/// `atr = (prev_atr * (period - 1) + tr) / period`.
///
/// # Examples
/// ```rust
/// use mantis_ta::indicators::{Indicator, ATR};
/// use mantis_ta::types::Candle;
///
/// let candles: Vec<Candle> = [
///     (1.0, 0.5, 0.8),
///     (2.0, 0.5, 1.5),
///     (3.0, 1.0, 2.5),
///     (3.5, 1.5, 3.0),
/// ]
/// .iter()
/// .enumerate()
/// .map(|(i, (h, l, c))| Candle {
///     timestamp: i as i64,
///     open: *c,
///     high: *h,
///     low: *l,
///     close: *c,
///     volume: 0.0,
/// })
/// .collect();
///
/// let out = ATR::new(3).calculate(&candles);
/// assert!(out.iter().take(2).all(|v| v.is_none()));
///
/// // True ranges are 0.5, 1.5, 2.0, 2.0: the seed is (0.5 + 1.5 + 2.0) / 3,
/// // and the next value is (4/3 * 2 + 2.0) / 3 = 14/9.
/// assert!((out[2].unwrap() - 4.0 / 3.0).abs() < 1e-12);
/// assert!((out[3].unwrap() - 14.0 / 9.0).abs() < 1e-12);
/// ```
#[derive(Debug, Clone)]
pub struct ATR {
    /// Smoothing length; also the number of candles needed for the first value.
    period: usize,
    /// Close of the most recently processed candle, if any.
    prev_close: Option<f64>,
    /// Number of true ranges accumulated so far during warm-up.
    count: usize,
    /// Running sum of true ranges during warm-up (used to seed the average).
    sum_tr: f64,
    /// Current smoothed value; `None` until warm-up completes.
    atr: Option<f64>,
}
41
42impl ATR {
43    pub fn new(period: usize) -> Self {
44        assert!(period > 0, "period must be > 0");
45        Self {
46            period,
47            prev_close: None,
48            count: 0,
49            sum_tr: 0.0,
50            atr: None,
51        }
52    }
53
54    #[inline]
55    fn true_range(&self, candle: &Candle) -> f64 {
56        let hl = candle.high - candle.low;
57        match self.prev_close {
58            None => hl,
59            Some(prev) => {
60                let h_pc = (candle.high - prev).abs();
61                let l_pc = (candle.low - prev).abs();
62                hl.max(h_pc).max(l_pc)
63            }
64        }
65    }
66}
67
68impl Indicator for ATR {
69    type Output = f64;
70
71    fn next(&mut self, candle: &Candle) -> Option<Self::Output> {
72        let tr = self.true_range(candle);
73
74        let output = if let Some(prev_atr) = self.atr {
75            let next_atr = (prev_atr * (self.period as f64 - 1.0) + tr) / self.period as f64;
76            self.atr = Some(next_atr);
77            self.atr
78        } else {
79            self.sum_tr += tr;
80            self.count += 1;
81            if self.count >= self.period {
82                let initial = self.sum_tr / self.period as f64;
83                self.atr = Some(initial);
84                self.atr
85            } else {
86                None
87            }
88        };
89
90        self.prev_close = Some(candle.close);
91        output
92    }
93
94    fn reset(&mut self) {
95        self.prev_close = None;
96        self.count = 0;
97        self.sum_tr = 0.0;
98        self.atr = None;
99    }
100
101    fn warmup_period(&self) -> usize {
102        self.period
103    }
104
105    fn clone_boxed(&self) -> Box<dyn Indicator<Output = Self::Output>> {
106        Box::new(self.clone())
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// `(high, low, close)` triples shared by several tests.
    const SAMPLE: [(f64, f64, f64); 4] = [
        (1.0, 0.5, 0.8),
        (2.0, 0.5, 1.5),
        (3.0, 1.0, 2.5),
        (3.5, 1.5, 3.0),
    ];

    /// Builds candles from `(high, low, close)` triples. `open` mirrors
    /// `close` and volume is zero because ATR reads neither.
    fn candles(hlc: &[(f64, f64, f64)]) -> Vec<Candle> {
        hlc.iter()
            .map(|&(high, low, close)| Candle {
                timestamp: 0,
                open: close,
                high,
                low,
                close,
                volume: 0.0,
            })
            .collect()
    }

    /// Feeds every candle through `atr`, collecting each output.
    fn run(atr: &mut ATR, candles: &[Candle]) -> Vec<Option<f64>> {
        candles.iter().map(|c| atr.next(c)).collect()
    }

    /// Asserts that `actual` is `Some` and within 1e-12 of `expected`.
    fn assert_close(actual: Option<f64>, expected: f64) {
        let value = actual.expect("expected ATR to produce a value");
        assert!(
            (value - expected).abs() < 1e-12,
            "got {}, expected {}",
            value,
            expected
        );
    }

    #[test]
    fn atr_emits_after_warmup() {
        let mut atr = ATR::new(3);
        let outputs = run(&mut atr, &candles(&SAMPLE));
        let wp = atr.warmup_period();
        // First wp-1 outputs should be None
        assert!(outputs.iter().take(wp - 1).all(|o| o.is_none()));
        // Output at index wp-1 should be the first Some
        assert!(outputs[wp - 1].is_some());
    }

    #[test]
    fn atr_matches_wilder_smoothing() {
        // True ranges: 0.5 (first bar, high - low), 1.5, 2.0, 2.0.
        let outputs = run(&mut ATR::new(3), &candles(&SAMPLE));
        assert_eq!(outputs.len(), 4);
        assert!(outputs[0].is_none() && outputs[1].is_none());
        // Seed: simple mean of the first three true ranges.
        assert_close(outputs[2], 4.0 / 3.0);
        // Wilder: (4/3 * 2 + 2.0) / 3.
        assert_close(outputs[3], 14.0 / 9.0);
    }

    #[test]
    fn true_range_accounts_for_gaps() {
        // With period 1 the ATR equals each candle's raw true range.
        let outputs = run(
            &mut ATR::new(1),
            &candles(&[(2.0, 1.0, 1.5), (5.0, 4.0, 4.5), (2.0, 1.5, 1.8)]),
        );
        assert_close(outputs[0], 1.0); // no previous close: high - low
        assert_close(outputs[1], 3.5); // gap up: |5.0 - 1.5|
        assert_close(outputs[2], 3.0); // gap down: |1.5 - 4.5|
    }

    #[test]
    fn reset_restores_initial_state() {
        let input = candles(&SAMPLE);
        let mut atr = ATR::new(3);
        let first = run(&mut atr, &input);
        atr.reset();
        let second = run(&mut atr, &input);
        assert_eq!(first, second);
    }

    #[test]
    #[should_panic(expected = "period must be > 0")]
    fn zero_period_panics() {
        let _ = ATR::new(0);
    }
}