Skip to main content

quantwave_core/indicators/
atr_ts.rs

1use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
2use crate::indicators::volatility::ATR;
3use crate::traits::Next;
4
/// ATR Trailing Stop.
///
/// A volatility-based trailing stop that tracks price in one direction at a time.
/// On every bar two candidate levels are derived from the close and the ATR:
///
/// - Long Stop  = Close - Multiplier * ATR
/// - Short Stop = Close + Multiplier * ATR
///
/// While long, the stop only ratchets upward (the `max` of the previous stop and the new
/// long candidate); a close below the previous stop flips the position to short and resets
/// the stop to the short candidate. The short side mirrors this with `min` and a close above
/// the previous stop. The very first bar always starts long.
///
/// Bars are fed as `(high, low, close)` through [`Next::next`], which yields
/// `(stop_level, direction)`.
#[derive(Debug, Clone)]
pub struct ATRTrailingStop {
    /// ATR calculator that receives every `(high, low, close)` bar.
    atr: ATR,
    /// Distance of each candidate stop from the close, measured in ATR units.
    multiplier: f64,
    /// Stop level emitted on the previous bar; `None` until the first bar is processed.
    prev_stop: Option<f64>,
    direction: i8, // 1 for Long, -1 for Short
}
16
17impl ATRTrailingStop {
18    pub fn new(period: usize, multiplier: f64) -> Self {
19        Self {
20            atr: ATR::new(period),
21            multiplier,
22            prev_stop: None,
23            direction: 1,
24        }
25    }
26}
27
28impl Next<(f64, f64, f64)> for ATRTrailingStop {
29    type Output = (f64, i8); // (Stop Level, Direction)
30
31    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
32        let current_atr = self.atr.next((high, low, close));
33        let long_stop = close - self.multiplier * current_atr;
34        let short_stop = close + self.multiplier * current_atr;
35
36        let prev_stop = match self.prev_stop {
37            Some(stop) => stop,
38            None => {
39                self.prev_stop = Some(long_stop);
40                self.direction = 1;
41                return (long_stop, 1);
42            }
43        };
44
45        if self.direction == 1 {
46            if close < prev_stop {
47                self.direction = -1;
48                self.prev_stop = Some(short_stop);
49                (short_stop, -1)
50            } else {
51                let new_stop = prev_stop.max(long_stop);
52                self.prev_stop = Some(new_stop);
53                (new_stop, 1)
54            }
55        } else {
56            if close > prev_stop {
57                self.direction = 1;
58                self.prev_stop = Some(long_stop);
59                (long_stop, 1)
60            } else {
61                let new_stop = prev_stop.min(short_stop);
62                self.prev_stop = Some(new_stop);
63                (new_stop, -1)
64            }
65        }
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use serde::Deserialize;
    use std::fs;
    use std::path::Path;

    /// Fixture path (relative to the crate or its parent directory) for period 14,
    /// multiplier 2.5.
    const GOLD_STANDARD_FIXTURE: &str = "tests/gold_standard/atr_ts_14_25.json";

    /// One gold-standard case: parallel per-bar input series and expected outputs.
    #[derive(Debug, Deserialize)]
    struct ATRTSCase {
        high: Vec<f64>,
        low: Vec<f64>,
        close: Vec<f64>,
        expected_stop: Vec<f64>,
        expected_dir: Vec<i8>,
    }

    #[test]
    fn test_atr_ts_gold_standard() {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
            .expect("CARGO_MANIFEST_DIR should be set when running under cargo");
        let manifest_path = Path::new(&manifest_dir);
        // Look inside this crate first, then fall back to the parent directory.
        let local = manifest_path.join(GOLD_STANDARD_FIXTURE);
        let path = if local.exists() {
            local
        } else {
            manifest_path
                .parent()
                .expect("crate directory should have a parent directory")
                .join(GOLD_STANDARD_FIXTURE)
        };
        let content = fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
        let case: ATRTSCase = serde_json::from_str(&content)
            .expect("gold-standard JSON should deserialize into ATRTSCase");

        // All series must describe the same bars; otherwise indexing would panic without
        // context, or trailing expectations would silently go unchecked.
        let bars = case.high.len();
        assert_eq!(case.low.len(), bars, "`low` length differs from `high`");
        assert_eq!(case.close.len(), bars, "`close` length differs from `high`");
        assert_eq!(case.expected_stop.len(), bars, "`expected_stop` length differs from `high`");
        assert_eq!(case.expected_dir.len(), bars, "`expected_dir` length differs from `high`");

        let mut atr_ts = ATRTrailingStop::new(14, 2.5);
        let inputs = case.high.iter().zip(&case.low).zip(&case.close);
        let expected = case.expected_stop.iter().zip(&case.expected_dir);
        for (i, (((&high, &low), &close), (&want_stop, &want_dir))) in
            inputs.zip(expected).enumerate()
        {
            let (stop, dir) = atr_ts.next((high, low, close));
            approx::assert_relative_eq!(stop, want_stop, epsilon = 1e-6);
            assert_eq!(dir, want_dir, "direction mismatch at bar {i}");
        }
    }

    /// Runs a fresh indicator over `data` in one pass and collects every output.
    fn atr_ts_batch(data: Vec<(f64, f64, f64)>, period: usize, multiplier: f64) -> Vec<(f64, i8)> {
        let mut atr_ts = ATRTrailingStop::new(period, multiplier);
        data.into_iter().map(|x| atr_ts.next(x)).collect()
    }

    proptest! {
        #[test]
        fn test_atr_ts_parity(
            input in prop::collection::vec((0.0f64..100.0, 0.0f64..100.0, 0.0f64..100.0), 1..100)
        ) {
            // Normalise each random triple into a valid bar: low <= close <= high.
            let bars: Vec<(f64, f64, f64)> = input
                .into_iter()
                .map(|(h, l, c)| (h.max(l).max(c), l.min(h).min(c), c))
                .collect();

            let period = 14;
            let multiplier = 2.5;
            let mut atr_ts = ATRTrailingStop::new(period, multiplier);
            let streaming_results: Vec<(f64, i8)> =
                bars.iter().map(|&bar| atr_ts.next(bar)).collect();

            let batch_results = atr_ts_batch(bars, period, multiplier);

            prop_assert_eq!(streaming_results.len(), batch_results.len());
            for (s, b) in streaming_results.iter().zip(batch_results.iter()) {
                approx::assert_relative_eq!(s.0, b.0, epsilon = 1e-6);
                assert_eq!(s.1, b.1);
            }
        }
    }

    #[test]
    fn test_atr_ts_basic() {
        let mut atr_ts = ATRTrailingStop::new(14, 2.5);

        // The first bar always opens long, with the stop below the close.
        let (stop1, dir1) = atr_ts.next((10.0, 8.0, 9.0));
        assert!(stop1 < 9.0);
        assert_eq!(dir1, 1);
    }
}
154
155pub const ATR_TS_METADATA: IndicatorMetadata = IndicatorMetadata {
156    name: "ATR Trailing Stop",
157    description: "A trailing stop based on Average True Range to keep trades in a trend.",
158    params: &[
159        ParamDef {
160            name: "period",
161            default: "10",
162            description: "ATR period",
163        },
164        ParamDef {
165            name: "multiplier",
166            default: "3.0",
167            description: "ATR Multiplier",
168        },
169    ],
170    formula_source: "https://www.tradingview.com/support/solutions/43000589105-average-true-range-atr/",
171    formula_latex: r#"
172\[
173Stop = P_{high} - (Multiplier \times ATR)
174\]
175"#,
176    gold_standard_file: "atr_ts.json",
177    category: "Classic",
178};