quantwave_core/indicators/
supertrend.rs1use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
2use crate::indicators::volatility::ATR;
3use crate::traits::Next;
4
/// Static descriptor for the SuperTrend indicator.
///
/// Bundles the display name, human-readable description and usage notes,
/// search keywords, parameter definitions (with their default values as
/// strings), the formula reference/LaTeX, the gold-standard fixture file name
/// and the category label.
///
/// The parameter defaults (`"10"`, `"3.0"`) and `gold_standard_file` line up
/// with the gold-standard test in this file, which builds
/// `SuperTrend::new(10, 3.0)` and loads `supertrend_10_3.json`.
pub const METADATA: IndicatorMetadata = IndicatorMetadata {
    name: "SuperTrend",
    description: "Trend-following indicator that combines ATR for volatility bands to identify the primary market direction.",
    usage: "Use as a primary trend-following indicator and dynamic stop-loss. A SuperTrend flip from bearish to bullish (or vice versa) provides a clear, rule-based entry and exit signal.",
    keywords: &["trend", "atr", "stop-loss", "classic", "breakout"],
    ehlers_summary: "SuperTrend computes upper and lower ATR-based bands around the midpoint of each bar. The active line flips from upper to lower (and vice versa) only when price closes beyond the band, providing a clean directional bias and a trailing stop level in one indicator. — TradingView Community",
    // One entry per argument of `SuperTrend::new`, in the same order.
    params: &[
        ParamDef {
            name: "period",
            default: "10",
            description: "ATR length",
        },
        ParamDef {
            name: "multiplier",
            default: "3.0",
            description: "ATR multiplier",
        },
    ],
    formula_source: "https://www.tradingview.com/script/7zF0a4f8-SuperTrend-by-Mobius/",
    // Raw string so the LaTeX backslashes need no escaping.
    formula_latex: r#"
\[
\text{SuperTrend} = \begin{cases}
\text{LowerBand} & \text{if trend is up} \\
\text{UpperBand} & \text{if trend is down}
\end{cases}
\]
"#,
    gold_standard_file: "supertrend_10_3.json",
    category: "Classic",
};
35
/// Streaming SuperTrend indicator state.
///
/// Holds the ATR calculator plus the values carried over from the previously
/// processed bar that are needed to ratchet the bands and decide whether the
/// trend direction flips.
#[derive(Debug, Clone)]
pub struct SuperTrend {
    /// ATR calculator; fed each bar's `(high, low, close)`.
    atr: ATR,
    /// Factor applied to the ATR value to offset the bands from the bar midpoint.
    multiplier: f64,
    /// Close of the previously processed bar; `None` until the first bar is seen.
    prev_close: Option<f64>,
    /// Upper band produced for the previous bar; `None` until the first bar is seen.
    prev_upper_band: Option<f64>,
    /// Lower band produced for the previous bar; `None` until the first bar is seen.
    prev_lower_band: Option<f64>,
    /// Current trend direction: `1` while the lower band is reported (up-trend),
    /// `-1` while the upper band is reported (down-trend). Starts at `1`.
    direction: i8,
}
46
47impl SuperTrend {
48 pub fn new(period: usize, multiplier: f64) -> Self {
49 Self {
50 atr: ATR::new(period),
51 multiplier,
52 prev_close: None,
53 prev_upper_band: None,
54 prev_lower_band: None,
55 direction: 1,
56 }
57 }
58}
59
60impl Next<(f64, f64, f64)> for SuperTrend {
61 type Output = (f64, i8);
62
63 fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
64 let atr = self.atr.next((high, low, close));
65 let mid = (high + low) / 2.0;
66
67 let basic_upper = mid + self.multiplier * atr;
68 let basic_lower = mid - self.multiplier * atr;
69
70 let upper_band = match self.prev_upper_band {
71 Some(prev_upper) => {
72 if basic_upper < prev_upper || self.prev_close.unwrap_or(0.0) > prev_upper {
73 basic_upper
74 } else {
75 prev_upper
76 }
77 }
78 None => basic_upper,
79 };
80
81 let lower_band = match self.prev_lower_band {
82 Some(prev_lower) => {
83 if basic_lower > prev_lower || self.prev_close.unwrap_or(0.0) < prev_lower {
84 basic_lower
85 } else {
86 prev_lower
87 }
88 }
89 None => basic_lower,
90 };
91
92 if self.direction == -1 && close > upper_band {
93 self.direction = 1;
94 } else if self.direction == 1 && close < lower_band {
95 self.direction = -1;
96 }
97
98 let supertrend = if self.direction == 1 {
99 lower_band
100 } else {
101 upper_band
102 };
103
104 self.prev_close = Some(close);
105 self.prev_upper_band = Some(upper_band);
106 self.prev_lower_band = Some(lower_band);
107
108 (supertrend, self.direction)
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use serde::Deserialize;
    use std::fs;
    use std::path::Path;

    /// Layout of the gold-standard JSON fixture: parallel input columns plus
    /// the expected SuperTrend value and direction for each bar.
    #[derive(Debug, Deserialize)]
    struct SuperTrendCase {
        high: Vec<f64>,
        low: Vec<f64>,
        close: Vec<f64>,
        expected_st: Vec<f64>,
        expected_dir: Vec<i8>,
    }

    /// Replays the fixture bars through `SuperTrend::new(10, 3.0)` and checks
    /// every output against the recorded expectation.
    #[test]
    fn test_supertrend_gold_standard() {
        const FIXTURE: &str = "tests/gold_standard/supertrend_10_3.json";

        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
        let crate_root = Path::new(&manifest_dir);

        // Look inside the crate first, then fall back to the parent directory.
        let in_crate = crate_root.join(FIXTURE);
        let fixture_path = if in_crate.exists() {
            in_crate
        } else {
            crate_root.parent().unwrap().join(FIXTURE)
        };

        let raw = fs::read_to_string(fixture_path).unwrap();
        let case: SuperTrendCase = serde_json::from_str(&raw).unwrap();

        let mut indicator = SuperTrend::new(10, 3.0);
        for (i, &high) in case.high.iter().enumerate() {
            let (val, dir) = indicator.next((high, case.low[i], case.close[i]));
            approx::assert_relative_eq!(val, case.expected_st[i], epsilon = 1e-6);
            assert_eq!(dir, case.expected_dir[i]);
        }
    }

    /// Runs a fresh indicator over `data` and collects every output.
    fn supertrend_batch(
        data: Vec<(f64, f64, f64)>,
        period: usize,
        multiplier: f64,
    ) -> Vec<(f64, i8)> {
        let mut indicator = SuperTrend::new(period, multiplier);
        let mut outputs = Vec::with_capacity(data.len());
        for bar in data {
            outputs.push(indicator.next(bar));
        }
        outputs
    }

    proptest! {
        /// Feeding bars one at a time must agree with `supertrend_batch`.
        #[test]
        fn test_supertrend_parity(input in prop::collection::vec((0.0..100.0, 0.0..100.0, 0.0..100.0), 1..100)) {
            // Turn each random triple into a bar whose high/low bracket all
            // three generated values.
            let mut bars: Vec<(f64, f64, f64)> = Vec::with_capacity(input.len());
            for (h, l, c) in input {
                let (h, l, c): (f64, f64, f64) = (h, l, c);
                bars.push((h.max(l).max(c), l.min(h).min(c), c));
            }

            let period = 10;
            let multiplier = 3.0;

            let mut streaming = SuperTrend::new(period, multiplier);
            let streamed: Vec<(f64, i8)> = bars.iter().map(|&bar| streaming.next(bar)).collect();

            let batched = supertrend_batch(bars, period, multiplier);

            for (&(s_val, s_dir), &(b_val, b_dir)) in streamed.iter().zip(batched.iter()) {
                approx::assert_relative_eq!(s_val, b_val, epsilon = 1e-6);
                assert_eq!(s_dir, b_dir);
            }
        }
    }
}