// quantwave_core/indicators/supertrend.rs
use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
use crate::indicators::volatility::ATR;
use crate::traits::Next;

5pub const METADATA: IndicatorMetadata = IndicatorMetadata {
6 name: "SuperTrend",
7 description: "Trend-following indicator that combines ATR for volatility bands to identify the primary market direction.",
8 params: &[
9 ParamDef {
10 name: "period",
11 default: "10",
12 description: "ATR length",
13 },
14 ParamDef {
15 name: "multiplier",
16 default: "3.0",
17 description: "ATR multiplier",
18 },
19 ],
20 formula_source: "https://www.tradingview.com/script/7zF0a4f8-SuperTrend-by-Mobius/",
21 formula_latex: r#"
22\[
23\text{SuperTrend} = \begin{cases}
24\text{LowerBand} & \text{if trend is up} \\
25\text{UpperBand} & \text{if trend is down}
26\end{cases}
27\]
28"#,
29 gold_standard_file: "supertrend_10_3.json",
30 category: "Classic",
31};
32
33#[derive(Debug, Clone)]
35pub struct SuperTrend {
36 atr: ATR,
37 multiplier: f64,
38 prev_close: Option<f64>,
39 prev_upper_band: Option<f64>,
40 prev_lower_band: Option<f64>,
41 direction: i8, }
43
44impl SuperTrend {
45 pub fn new(period: usize, multiplier: f64) -> Self {
46 Self {
47 atr: ATR::new(period),
48 multiplier,
49 prev_close: None,
50 prev_upper_band: None,
51 prev_lower_band: None,
52 direction: 1,
53 }
54 }
55}
56
57impl Next<(f64, f64, f64)> for SuperTrend {
58 type Output = (f64, i8);
59
60 fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
61 let atr = self.atr.next((high, low, close));
62 let mid = (high + low) / 2.0;
63
64 let basic_upper = mid + self.multiplier * atr;
65 let basic_lower = mid - self.multiplier * atr;
66
67 let upper_band = match self.prev_upper_band {
68 Some(prev_upper) => {
69 if basic_upper < prev_upper || self.prev_close.unwrap_or(0.0) > prev_upper {
70 basic_upper
71 } else {
72 prev_upper
73 }
74 }
75 None => basic_upper,
76 };
77
78 let lower_band = match self.prev_lower_band {
79 Some(prev_lower) => {
80 if basic_lower > prev_lower || self.prev_close.unwrap_or(0.0) < prev_lower {
81 basic_lower
82 } else {
83 prev_lower
84 }
85 }
86 None => basic_lower,
87 };
88
89 if self.direction == -1 && close > upper_band {
90 self.direction = 1;
91 } else if self.direction == 1 && close < lower_band {
92 self.direction = -1;
93 }
94
95 let supertrend = if self.direction == 1 {
96 lower_band
97 } else {
98 upper_band
99 };
100
101 self.prev_close = Some(close);
102 self.prev_upper_band = Some(upper_band);
103 self.prev_lower_band = Some(lower_band);
104
105 (supertrend, self.direction)
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use serde::Deserialize;
    use std::fs;
    use std::path::Path;

    /// One gold-standard fixture: per-bar input columns plus the expected
    /// SuperTrend value and direction for each bar.
    #[derive(Debug, Deserialize)]
    struct SuperTrendCase {
        high: Vec<f64>,
        low: Vec<f64>,
        close: Vec<f64>,
        expected_st: Vec<f64>,
        expected_dir: Vec<i8>,
    }

    /// Replays `supertrend_10_3.json` through `SuperTrend::new(10, 3.0)` and
    /// compares every bar's value and direction with the fixture.
    #[test]
    fn test_supertrend_gold_standard() {
        const FIXTURE: &str = "tests/gold_standard/supertrend_10_3.json";

        let manifest_dir =
            std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR must be set by cargo");
        let manifest_path = Path::new(&manifest_dir);
        // Look next to this crate first, then one directory up
        // (presumably the workspace root).
        let path = manifest_path.join(FIXTURE);
        let path = if path.exists() {
            path
        } else {
            manifest_path
                .parent()
                .expect("manifest directory has a parent")
                .join(FIXTURE)
        };
        let content = fs::read_to_string(&path)
            .unwrap_or_else(|err| panic!("cannot read {}: {}", path.display(), err));
        let case: SuperTrendCase =
            serde_json::from_str(&content).expect("gold-standard fixture must be valid JSON");

        // Without these checks a mismatch would either panic with a bare index
        // error or silently skip trailing expectations.
        let bars = case.high.len();
        assert_eq!(case.low.len(), bars, "low column length mismatch");
        assert_eq!(case.close.len(), bars, "close column length mismatch");
        assert_eq!(case.expected_st.len(), bars, "expected_st length mismatch");
        assert_eq!(case.expected_dir.len(), bars, "expected_dir length mismatch");

        let mut st = SuperTrend::new(10, 3.0);
        for i in 0..bars {
            let (val, dir) = st.next((case.high[i], case.low[i], case.close[i]));
            approx::assert_relative_eq!(val, case.expected_st[i], epsilon = 1e-6);
            assert_eq!(dir, case.expected_dir[i], "direction mismatch at bar {}", i);
        }
    }

    /// Runs a fresh indicator over `data` in one pass and collects every output.
    fn supertrend_batch(
        data: Vec<(f64, f64, f64)>,
        period: usize,
        multiplier: f64,
    ) -> Vec<(f64, i8)> {
        let mut st = SuperTrend::new(period, multiplier);
        data.into_iter().map(|x| st.next(x)).collect()
    }

    proptest! {
        // Both paths share one implementation, so the equality checks mainly
        // guard determinism; the direction check adds a real invariant.
        #[test]
        fn test_supertrend_parity(
            input in prop::collection::vec(
                (0.0f64..100.0, 0.0f64..100.0, 0.0f64..100.0),
                1..100,
            )
        ) {
            // Normalise each draw into a well-formed bar: low <= close <= high.
            let adj_input: Vec<(f64, f64, f64)> = input
                .into_iter()
                .map(|(h, l, c)| (h.max(l).max(c), l.min(h).min(c), c))
                .collect();

            let period = 10;
            let multiplier = 3.0;
            let mut st = SuperTrend::new(period, multiplier);
            let streaming_results: Vec<(f64, i8)> =
                adj_input.iter().map(|&bar| st.next(bar)).collect();

            let batch_results = supertrend_batch(adj_input, period, multiplier);

            prop_assert_eq!(streaming_results.len(), batch_results.len());
            for (s, b) in streaming_results.iter().zip(batch_results.iter()) {
                approx::assert_relative_eq!(s.0, b.0, epsilon = 1e-6);
                prop_assert_eq!(s.1, b.1);
                prop_assert!(s.1 == 1 || s.1 == -1, "invalid direction {}", s.1);
            }
        }
    }
}