quantwave_core/indicators/
atr_ts.rs1use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
2use crate::indicators::volatility::ATR;
3use crate::traits::Next;
4
5#[derive(Debug, Clone)]
10pub struct ATRTrailingStop {
11 atr: ATR,
12 multiplier: f64,
13 prev_stop: Option<f64>,
14 direction: i8, }
16
17impl ATRTrailingStop {
18 pub fn new(period: usize, multiplier: f64) -> Self {
19 Self {
20 atr: ATR::new(period),
21 multiplier,
22 prev_stop: None,
23 direction: 1,
24 }
25 }
26}
27
28impl Next<(f64, f64, f64)> for ATRTrailingStop {
29 type Output = (f64, i8); fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
32 let current_atr = self.atr.next((high, low, close));
33 let long_stop = close - self.multiplier * current_atr;
34 let short_stop = close + self.multiplier * current_atr;
35
36 let prev_stop = match self.prev_stop {
37 Some(stop) => stop,
38 None => {
39 self.prev_stop = Some(long_stop);
40 self.direction = 1;
41 return (long_stop, 1);
42 }
43 };
44
45 if self.direction == 1 {
46 if close < prev_stop {
47 self.direction = -1;
48 self.prev_stop = Some(short_stop);
49 (short_stop, -1)
50 } else {
51 let new_stop = prev_stop.max(long_stop);
52 self.prev_stop = Some(new_stop);
53 (new_stop, 1)
54 }
55 } else {
56 if close > prev_stop {
57 self.direction = 1;
58 self.prev_stop = Some(long_stop);
59 (long_stop, 1)
60 } else {
61 let new_stop = prev_stop.min(short_stop);
62 self.prev_stop = Some(new_stop);
63 (new_stop, -1)
64 }
65 }
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use serde::Deserialize;
    use std::fs;
    use std::path::{Path, PathBuf};

    /// One gold-standard fixture: parallel per-bar input columns plus the
    /// expected `(stop, direction)` output for each bar.
    #[derive(Debug, Deserialize)]
    struct ATRTSCase {
        high: Vec<f64>,
        low: Vec<f64>,
        close: Vec<f64>,
        expected_stop: Vec<f64>,
        expected_dir: Vec<i8>,
    }

    /// Resolves `tests/gold_standard/<file_name>` under this crate's manifest
    /// directory, falling back to the manifest directory's parent when the file
    /// is not present locally.
    fn gold_standard_path(file_name: &str) -> PathBuf {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
            .expect("CARGO_MANIFEST_DIR should be set when running under cargo");
        let manifest_path = Path::new(&manifest_dir);
        let relative = Path::new("tests/gold_standard").join(file_name);
        let local = manifest_path.join(&relative);
        if local.exists() {
            local
        } else {
            manifest_path
                .parent()
                .expect("manifest directory should have a parent")
                .join(&relative)
        }
    }

    #[test]
    fn test_atr_ts_gold_standard() {
        let path = gold_standard_path("atr_ts_14_25.json");
        let content = fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("failed to read {}: {}", path.display(), e));
        let case: ATRTSCase = serde_json::from_str(&content)
            .unwrap_or_else(|e| panic!("failed to parse {}: {}", path.display(), e));

        // All columns must describe the same bars. Otherwise walking the fixture
        // would silently drop expectations, and an empty fixture would pass
        // vacuously.
        let n = case.high.len();
        assert!(n > 0, "gold-standard fixture {} has no bars", path.display());
        assert_eq!(case.low.len(), n, "`low` length differs from `high`");
        assert_eq!(case.close.len(), n, "`close` length differs from `high`");
        assert_eq!(
            case.expected_stop.len(),
            n,
            "`expected_stop` length differs from `high`"
        );
        assert_eq!(
            case.expected_dir.len(),
            n,
            "`expected_dir` length differs from `high`"
        );

        let bars = case.high.iter().zip(&case.low).zip(&case.close);
        let expected = case.expected_stop.iter().zip(&case.expected_dir);

        let mut atr_ts = ATRTrailingStop::new(14, 2.5);
        for (i, (((&high, &low), &close), (&want_stop, &want_dir))) in
            bars.zip(expected).enumerate()
        {
            let (stop, dir) = atr_ts.next((high, low, close));
            approx::assert_relative_eq!(stop, want_stop, epsilon = 1e-6);
            assert_eq!(dir, want_dir, "direction mismatch at bar {}", i);
        }
    }

    /// Runs a fresh indicator over `data` in one pass and collects every output.
    fn atr_ts_batch(data: &[(f64, f64, f64)], period: usize, multiplier: f64) -> Vec<(f64, i8)> {
        let mut atr_ts = ATRTrailingStop::new(period, multiplier);
        data.iter().map(|&bar| atr_ts.next(bar)).collect()
    }

    proptest! {
        #[test]
        fn test_atr_ts_parity(
            input in prop::collection::vec(
                (0.0f64..100.0, 0.0f64..100.0, 0.0f64..100.0),
                1..100,
            )
        ) {
            // Normalise each random triple into a well-formed bar: low <= close <= high.
            let bars: Vec<(f64, f64, f64)> = input
                .into_iter()
                .map(|(h, l, c)| (h.max(l).max(c), l.min(h).min(c), c))
                .collect();

            let period = 14;
            let multiplier = 2.5;

            let mut atr_ts = ATRTrailingStop::new(period, multiplier);
            let streaming_results: Vec<(f64, i8)> =
                bars.iter().map(|&bar| atr_ts.next(bar)).collect();
            let batch_results = atr_ts_batch(&bars, period, multiplier);

            prop_assert_eq!(streaming_results.len(), batch_results.len());
            for ((s, b), &(_, _, close)) in streaming_results
                .iter()
                .zip(&batch_results)
                .zip(&bars)
            {
                // Both runs use the same `next`, so this mainly checks determinism.
                approx::assert_relative_eq!(s.0, b.0, epsilon = 1e-6);
                prop_assert_eq!(s.1, b.1);

                // Behavioural invariant implied by `next` (given a non-negative,
                // finite ATR for well-formed bars): the stop sits at or below the
                // close while long and at or above it while short.
                if s.1 == 1 {
                    prop_assert!(s.0 <= close, "long stop {} above close {}", s.0, close);
                } else {
                    prop_assert_eq!(s.1, -1);
                    prop_assert!(s.0 >= close, "short stop {} below close {}", s.0, close);
                }
            }
        }
    }

    #[test]
    fn test_atr_ts_basic() {
        let mut atr_ts = ATRTrailingStop::new(14, 2.5);

        let (stop1, dir1) = atr_ts.next((10.0, 8.0, 9.0));
        assert!(stop1 < 9.0);
        assert_eq!(dir1, 1);
    }
}
154
155pub const ATR_TS_METADATA: IndicatorMetadata = IndicatorMetadata {
156 name: "ATR Trailing Stop",
157 description: "A trailing stop based on Average True Range to keep trades in a trend.",
158 params: &[
159 ParamDef {
160 name: "period",
161 default: "10",
162 description: "ATR period",
163 },
164 ParamDef {
165 name: "multiplier",
166 default: "3.0",
167 description: "ATR Multiplier",
168 },
169 ],
170 formula_source: "https://www.tradingview.com/support/solutions/43000589105-average-true-range-atr/",
171 formula_latex: r#"
172\[
173Stop = P_{high} - (Multiplier \times ATR)
174\]
175"#,
176 gold_standard_file: "atr_ts.json",
177 category: "Classic",
178};