Skip to main content

quantwave_core/indicators/
supertrend.rs

1use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
2use crate::indicators::volatility::ATR;
3use crate::traits::Next;
4
5pub const METADATA: IndicatorMetadata = IndicatorMetadata {
6    name: "SuperTrend",
7    description: "Trend-following indicator that combines ATR for volatility bands to identify the primary market direction.",
8    usage: "Use as a primary trend-following indicator and dynamic stop-loss. A SuperTrend flip from bearish to bullish (or vice versa) provides a clear, rule-based entry and exit signal.",
9    keywords: &["trend", "atr", "stop-loss", "classic", "breakout"],
10    ehlers_summary: "SuperTrend computes upper and lower ATR-based bands around the midpoint of each bar. The active line flips from upper to lower (and vice versa) only when price closes beyond the band, providing a clean directional bias and a trailing stop level in one indicator. — TradingView Community",
11    params: &[
12        ParamDef {
13            name: "period",
14            default: "10",
15            description: "ATR length",
16        },
17        ParamDef {
18            name: "multiplier",
19            default: "3.0",
20            description: "ATR multiplier",
21        },
22    ],
23    formula_source: "https://www.tradingview.com/script/7zF0a4f8-SuperTrend-by-Mobius/",
24    formula_latex: r#"
25\[
26\text{SuperTrend} = \begin{cases}
27\text{LowerBand} & \text{if trend is up} \\
28\text{UpperBand} & \text{if trend is down}
29\end{cases}
30\]
31"#,
32    gold_standard_file: "supertrend_10_3.json",
33    category: "Classic",
34};
35
36/// SuperTrend Indicator
37#[derive(Debug, Clone)]
38pub struct SuperTrend {
39    atr: ATR,
40    multiplier: f64,
41    prev_close: Option<f64>,
42    prev_upper_band: Option<f64>,
43    prev_lower_band: Option<f64>,
44    direction: i8, // 1 for up, -1 for down
45}
46
47impl SuperTrend {
48    pub fn new(period: usize, multiplier: f64) -> Self {
49        Self {
50            atr: ATR::new(period),
51            multiplier,
52            prev_close: None,
53            prev_upper_band: None,
54            prev_lower_band: None,
55            direction: 1,
56        }
57    }
58}
59
60impl Next<(f64, f64, f64)> for SuperTrend {
61    type Output = (f64, i8);
62
63    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
64        let atr = self.atr.next((high, low, close));
65        let mid = (high + low) / 2.0;
66
67        let basic_upper = mid + self.multiplier * atr;
68        let basic_lower = mid - self.multiplier * atr;
69
70        let upper_band = match self.prev_upper_band {
71            Some(prev_upper) => {
72                if basic_upper < prev_upper || self.prev_close.unwrap_or(0.0) > prev_upper {
73                    basic_upper
74                } else {
75                    prev_upper
76                }
77            }
78            None => basic_upper,
79        };
80
81        let lower_band = match self.prev_lower_band {
82            Some(prev_lower) => {
83                if basic_lower > prev_lower || self.prev_close.unwrap_or(0.0) < prev_lower {
84                    basic_lower
85                } else {
86                    prev_lower
87                }
88            }
89            None => basic_lower,
90        };
91
92        if self.direction == -1 && close > upper_band {
93            self.direction = 1;
94        } else if self.direction == 1 && close < lower_band {
95            self.direction = -1;
96        }
97
98        let supertrend = if self.direction == 1 {
99            lower_band
100        } else {
101            upper_band
102        };
103
104        self.prev_close = Some(close);
105        self.prev_upper_band = Some(upper_band);
106        self.prev_lower_band = Some(lower_band);
107
108        (supertrend, self.direction)
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use serde::Deserialize;
    use std::fs;
    use std::path::{Path, PathBuf};

    /// Gold-standard fixture: parallel per-bar input series plus the expected outputs.
    #[derive(Debug, Deserialize)]
    struct SuperTrendCase {
        high: Vec<f64>,
        low: Vec<f64>,
        close: Vec<f64>,
        expected_st: Vec<f64>,
        expected_dir: Vec<i8>,
    }

    /// Locates a fixture under `<crate>/tests/gold_standard`, falling back to the
    /// `tests/gold_standard` directory next to the crate (workspace layout).
    fn gold_standard_path(file_name: &str) -> PathBuf {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
            .expect("CARGO_MANIFEST_DIR is set when running under cargo");
        let manifest_path = Path::new(&manifest_dir);
        let relative = Path::new("tests/gold_standard").join(file_name);

        let local = manifest_path.join(&relative);
        if local.exists() {
            local
        } else {
            manifest_path
                .parent()
                .expect("crate manifest directory has a parent directory")
                .join(relative)
        }
    }

    #[test]
    fn test_supertrend_gold_standard() {
        let path = gold_standard_path("supertrend_10_3.json");
        let content = fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("failed to read {}: {}", path.display(), e));
        let case: SuperTrendCase =
            serde_json::from_str(&content).expect("gold-standard fixture is valid JSON");

        // A malformed fixture would otherwise surface as an opaque index-out-of-bounds panic.
        let n = case.high.len();
        assert_eq!(case.low.len(), n, "low series length mismatch");
        assert_eq!(case.close.len(), n, "close series length mismatch");
        assert_eq!(case.expected_st.len(), n, "expected_st length mismatch");
        assert_eq!(case.expected_dir.len(), n, "expected_dir length mismatch");

        let mut st = SuperTrend::new(10, 3.0);
        for i in 0..n {
            let (val, dir) = st.next((case.high[i], case.low[i], case.close[i]));
            approx::assert_relative_eq!(val, case.expected_st[i], epsilon = 1e-6);
            assert_eq!(dir, case.expected_dir[i], "direction mismatch at bar {}", i);
        }
    }

    /// Runs a fresh indicator over `data` and collects every output.
    fn supertrend_batch(
        data: &[(f64, f64, f64)],
        period: usize,
        multiplier: f64,
    ) -> Vec<(f64, i8)> {
        let mut st = SuperTrend::new(period, multiplier);
        data.iter().map(|&bar| st.next(bar)).collect()
    }

    proptest! {
        #[test]
        fn test_supertrend_parity(
            input in prop::collection::vec((0.0f64..100.0, 0.0f64..100.0, 0.0f64..100.0), 1..100)
        ) {
            // Build well-formed bars: high is the max and low the min of the triple,
            // so low <= close <= high always holds.
            let bars: Vec<(f64, f64, f64)> = input
                .into_iter()
                .map(|(h, l, c)| (h.max(l).max(c), l.min(h).min(c), c))
                .collect();

            let period = 10;
            let multiplier = 3.0;

            let mut st = SuperTrend::new(period, multiplier);
            let streaming: Vec<(f64, i8)> = bars.iter().map(|&bar| st.next(bar)).collect();
            let batch = supertrend_batch(&bars, period, multiplier);

            prop_assert_eq!(streaming.len(), batch.len());
            for (s, b) in streaming.iter().zip(&batch) {
                approx::assert_relative_eq!(s.0, b.0, epsilon = 1e-6);
                prop_assert_eq!(s.1, b.1);
                // The direction state is only ever +1 (up) or -1 (down).
                prop_assert!(s.1 == 1 || s.1 == -1);
            }
        }
    }
}