// quantwave_core/indicators/supertrend.rs

1use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
2use crate::indicators::volatility::ATR;
3use crate::traits::Next;
4
5pub const METADATA: IndicatorMetadata = IndicatorMetadata {
6    name: "SuperTrend",
7    description: "Trend-following indicator that combines ATR for volatility bands to identify the primary market direction.",
8    params: &[
9        ParamDef {
10            name: "period",
11            default: "10",
12            description: "ATR length",
13        },
14        ParamDef {
15            name: "multiplier",
16            default: "3.0",
17            description: "ATR multiplier",
18        },
19    ],
20    formula_source: "https://www.tradingview.com/script/7zF0a4f8-SuperTrend-by-Mobius/",
21    formula_latex: r#"
22\[
23\text{SuperTrend} = \begin{cases}
24\text{LowerBand} & \text{if trend is up} \\
25\text{UpperBand} & \text{if trend is down}
26\end{cases}
27\]
28"#,
29    gold_standard_file: "supertrend_10_3.json",
30    category: "Classic",
31};
32
33/// SuperTrend Indicator
34#[derive(Debug, Clone)]
35pub struct SuperTrend {
36    atr: ATR,
37    multiplier: f64,
38    prev_close: Option<f64>,
39    prev_upper_band: Option<f64>,
40    prev_lower_band: Option<f64>,
41    direction: i8, // 1 for up, -1 for down
42}
43
44impl SuperTrend {
45    pub fn new(period: usize, multiplier: f64) -> Self {
46        Self {
47            atr: ATR::new(period),
48            multiplier,
49            prev_close: None,
50            prev_upper_band: None,
51            prev_lower_band: None,
52            direction: 1,
53        }
54    }
55}
56
57impl Next<(f64, f64, f64)> for SuperTrend {
58    type Output = (f64, i8);
59
60    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
61        let atr = self.atr.next((high, low, close));
62        let mid = (high + low) / 2.0;
63
64        let basic_upper = mid + self.multiplier * atr;
65        let basic_lower = mid - self.multiplier * atr;
66
67        let upper_band = match self.prev_upper_band {
68            Some(prev_upper) => {
69                if basic_upper < prev_upper || self.prev_close.unwrap_or(0.0) > prev_upper {
70                    basic_upper
71                } else {
72                    prev_upper
73                }
74            }
75            None => basic_upper,
76        };
77
78        let lower_band = match self.prev_lower_band {
79            Some(prev_lower) => {
80                if basic_lower > prev_lower || self.prev_close.unwrap_or(0.0) < prev_lower {
81                    basic_lower
82                } else {
83                    prev_lower
84                }
85            }
86            None => basic_lower,
87        };
88
89        if self.direction == -1 && close > upper_band {
90            self.direction = 1;
91        } else if self.direction == 1 && close < lower_band {
92            self.direction = -1;
93        }
94
95        let supertrend = if self.direction == 1 {
96            lower_band
97        } else {
98            upper_band
99        };
100
101        self.prev_close = Some(close);
102        self.prev_upper_band = Some(upper_band);
103        self.prev_lower_band = Some(lower_band);
104
105        (supertrend, self.direction)
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use serde::Deserialize;
    use std::fs;
    use std::path::{Path, PathBuf};

    /// One gold-standard fixture: aligned input series plus the expected
    /// SuperTrend value and direction for every bar.
    #[derive(Debug, Deserialize)]
    struct SuperTrendCase {
        high: Vec<f64>,
        low: Vec<f64>,
        close: Vec<f64>,
        expected_st: Vec<f64>,
        expected_dir: Vec<i8>,
    }

    /// Locates the fixture named by `METADATA.gold_standard_file`.
    ///
    /// Looks in this crate's `tests/gold_standard` directory first, then in
    /// the same relative directory under the manifest directory's parent.
    fn gold_standard_path() -> PathBuf {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
            .expect("CARGO_MANIFEST_DIR is set by cargo when running tests");
        let manifest_path = Path::new(&manifest_dir);
        let relative = Path::new("tests/gold_standard").join(METADATA.gold_standard_file);

        let local = manifest_path.join(&relative);
        if local.exists() {
            return local;
        }
        manifest_path
            .parent()
            .expect("crate manifest directory has a parent")
            .join(relative)
    }

    #[test]
    fn test_supertrend_gold_standard() {
        let path = gold_standard_path();
        let content = fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("failed to read {}: {}", path.display(), e));
        let case: SuperTrendCase = serde_json::from_str(&content)
            .unwrap_or_else(|e| panic!("failed to parse {}: {}", path.display(), e));

        // A ragged fixture would otherwise either panic on an out-of-range
        // index or silently leave trailing bars unchecked.
        let n = case.high.len();
        assert_eq!(case.low.len(), n, "low series length differs from high");
        assert_eq!(case.close.len(), n, "close series length differs from high");
        assert_eq!(case.expected_st.len(), n, "expected_st length differs from high");
        assert_eq!(case.expected_dir.len(), n, "expected_dir length differs from high");

        // The fixture was generated with period = 10 and multiplier = 3.0.
        let mut st = SuperTrend::new(10, 3.0);
        let bars = case.high.iter().zip(&case.low).zip(&case.close);
        let expected = case.expected_st.iter().zip(&case.expected_dir);
        for (i, (((&h, &l), &c), (&exp_st, &exp_dir))) in bars.zip(expected).enumerate() {
            let (val, dir) = st.next((h, l, c));
            approx::assert_relative_eq!(val, exp_st, epsilon = 1e-6);
            assert_eq!(dir, exp_dir, "direction mismatch at bar {}", i);
        }
    }

    /// Runs a fresh indicator over `data` and collects every output.
    fn supertrend_batch(
        data: Vec<(f64, f64, f64)>,
        period: usize,
        multiplier: f64,
    ) -> Vec<(f64, i8)> {
        let mut st = SuperTrend::new(period, multiplier);
        data.into_iter().map(|x| st.next(x)).collect()
    }

    proptest! {
        #[test]
        fn test_supertrend_parity(
            input in prop::collection::vec((0.0f64..100.0, 0.0f64..100.0, 0.0f64..100.0), 1..100)
        ) {
            // Rearrange each triple so that low <= close <= high, as in a real bar.
            let adj_input: Vec<(f64, f64, f64)> = input
                .into_iter()
                .map(|(h, l, c)| (h.max(l).max(c), l.min(h).min(c), c))
                .collect();

            let period = 10;
            let multiplier = 3.0;

            // Stream the first half, then clone the indicator and finish the
            // series on the clone. All streaming state must be captured by
            // `Clone`, so the result has to match an uninterrupted batch run.
            let split = adj_input.len() / 2;
            let mut head = SuperTrend::new(period, multiplier);
            let mut streaming_results: Vec<(f64, i8)> =
                adj_input[..split].iter().map(|&bar| head.next(bar)).collect();
            let mut tail = head.clone();
            streaming_results.extend(adj_input[split..].iter().map(|&bar| tail.next(bar)));

            let batch_results = supertrend_batch(adj_input, period, multiplier);

            prop_assert_eq!(streaming_results.len(), batch_results.len());
            for (s, b) in streaming_results.iter().zip(batch_results.iter()) {
                prop_assert!(s.1 == 1 || s.1 == -1, "direction must be +1 or -1, got {}", s.1);
                approx::assert_relative_eq!(s.0, b.0, epsilon = 1e-6);
                prop_assert_eq!(s.1, b.1);
            }
        }
    }
}
184                assert_eq!(s.1, b.1);
185            }
186        }
187    }
188}