Skip to main content

finance_query/indicators/
supertrend.rs

//! SuperTrend indicator: a trend-following overlay derived from the Average True Range (ATR).
2
3use super::{IndicatorError, Result, atr::atr};
4use serde::{Deserialize, Serialize};
5
6/// Result of SuperTrend calculation
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub struct SuperTrendResult {
9    /// SuperTrend line
10    pub value: Vec<Option<f64>>,
11    /// Trend direction (true = up, false = down)
12    pub is_uptrend: Vec<Option<bool>>,
13}
14
15/// Calculate SuperTrend.
16///
17/// Trend-following indicator based on ATR.
18///
19/// # Arguments
20///
21/// * `highs` - High prices
22/// * `lows` - Low prices
23/// * `closes` - Close prices
24/// * `period` - ATR period
25/// * `multiplier` - ATR multiplier
26///
27/// # Example
28///
29/// ```
30/// use finance_query::indicators::supertrend;
31///
32/// let highs = vec![10.0; 20];
33/// let lows = vec![8.0; 20];
34/// let closes = vec![9.0; 20];
35/// let result = supertrend(&highs, &lows, &closes, 10, 3.0).unwrap();
36/// ```
37pub fn supertrend(
38    highs: &[f64],
39    lows: &[f64],
40    closes: &[f64],
41    period: usize,
42    multiplier: f64,
43) -> Result<SuperTrendResult> {
44    if period == 0 {
45        return Err(IndicatorError::InvalidPeriod(
46            "Period must be greater than 0".to_string(),
47        ));
48    }
49    let len = highs.len();
50    if lows.len() != len || closes.len() != len {
51        return Err(IndicatorError::InvalidPeriod(
52            "Data lengths must match".to_string(),
53        ));
54    }
55    if len < period {
56        return Err(IndicatorError::InsufficientData {
57            need: period,
58            got: len,
59        });
60    }
61
62    let atr_values = atr(highs, lows, closes, period)?;
63
64    let mut supertrend = vec![None; len];
65    let mut is_uptrend = vec![None; len];
66
67    let start_idx = period - 1;
68
69    let mut prev_final_upper = 0.0;
70    let mut prev_final_lower = 0.0;
71    let mut prev_trend = true;
72
73    for i in start_idx..len {
74        if let Some(atr_val) = atr_values[i] {
75            let hl2 = (highs[i] + lows[i]) / 2.0;
76            let basic_upper = hl2 + (multiplier * atr_val);
77            let basic_lower = hl2 - (multiplier * atr_val);
78
79            let current_close = closes[i];
80            let prev_close = if i > 0 { closes[i - 1] } else { current_close };
81
82            let final_upper = if i == start_idx
83                || basic_upper < prev_final_upper
84                || prev_close > prev_final_upper
85            {
86                basic_upper
87            } else {
88                prev_final_upper
89            };
90
91            let final_lower = if i == start_idx
92                || basic_lower > prev_final_lower
93                || prev_close < prev_final_lower
94            {
95                basic_lower
96            } else {
97                prev_final_lower
98            };
99
100            let trend = if i == start_idx {
101                true
102            } else if prev_trend && current_close <= final_lower {
103                false
104            } else if !prev_trend && current_close >= final_upper {
105                true
106            } else {
107                prev_trend
108            };
109
110            let st_val = if trend { final_lower } else { final_upper };
111
112            supertrend[i] = Some(st_val);
113            is_uptrend[i] = Some(trend);
114
115            prev_final_upper = final_upper;
116            prev_final_lower = final_lower;
117            prev_trend = trend;
118        }
119    }
120
121    Ok(SuperTrendResult {
122        value: supertrend,
123        is_uptrend,
124    })
125}
126
127/// Internal variant accepting pre-computed dense ATR values (avoids redundant ATR computation).
128/// `atr_dense[k]` corresponds to original index `k + atr_period - 1`.
129pub(crate) fn supertrend_with_atr_dense(
130    highs: &[f64],
131    lows: &[f64],
132    closes: &[f64],
133    atr_dense: &[f64],
134    atr_period: usize,
135    multiplier: f64,
136) -> Result<SuperTrendResult> {
137    let len = highs.len();
138    if lows.len() != len || closes.len() != len {
139        return Err(IndicatorError::InvalidPeriod(
140            "Data lengths must match".to_string(),
141        ));
142    }
143    let start_idx = atr_period - 1;
144    let atr_off = start_idx;
145    let mut supertrend = vec![None; len];
146    let mut is_uptrend = vec![None; len];
147    let mut prev_final_upper = 0.0;
148    let mut prev_final_lower = 0.0;
149    let mut prev_trend = true;
150    for i in start_idx..len {
151        let atr_val = atr_dense[i - atr_off];
152        let hl2 = (highs[i] + lows[i]) / 2.0;
153        let basic_upper = hl2 + multiplier * atr_val;
154        let basic_lower = hl2 - multiplier * atr_val;
155        let current_close = closes[i];
156        let prev_close = if i > 0 { closes[i - 1] } else { current_close };
157        let final_upper =
158            if i == start_idx || basic_upper < prev_final_upper || prev_close > prev_final_upper {
159                basic_upper
160            } else {
161                prev_final_upper
162            };
163        let final_lower =
164            if i == start_idx || basic_lower > prev_final_lower || prev_close < prev_final_lower {
165                basic_lower
166            } else {
167                prev_final_lower
168            };
169        let trend = if i == start_idx {
170            true
171        } else if prev_trend && current_close <= final_lower {
172            false
173        } else if !prev_trend && current_close >= final_upper {
174            true
175        } else {
176            prev_trend
177        };
178        supertrend[i] = Some(if trend { final_lower } else { final_upper });
179        is_uptrend[i] = Some(trend);
180        prev_final_upper = final_upper;
181        prev_final_lower = final_lower;
182        prev_trend = trend;
183    }
184    Ok(SuperTrendResult {
185        value: supertrend,
186        is_uptrend,
187    })
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a flat series of `n` bars with high 10, low 8 and close 9.
    fn flat_series(n: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        (vec![10.0; n], vec![8.0; n], vec![9.0; n])
    }

    #[test]
    fn test_supertrend() {
        let (highs, lows, closes) = flat_series(20);
        let result = supertrend(&highs, &lows, &closes, 10, 3.0).unwrap();

        // One entry per bar; nothing before index `period - 1`.
        assert_eq!(result.value.len(), 20);
        assert_eq!(result.is_uptrend.len(), 20);
        assert!(result.value[8].is_none());
        assert!(result.is_uptrend[8].is_none());
        assert!(result.value[9].is_some());
        // The first computed bar always starts in an uptrend.
        assert_eq!(result.is_uptrend[9], Some(true));
    }

    #[test]
    fn test_supertrend_rejects_zero_period() {
        let (highs, lows, closes) = flat_series(20);
        let result = supertrend(&highs, &lows, &closes, 0, 3.0);
        assert!(matches!(result, Err(IndicatorError::InvalidPeriod(_))));
    }

    #[test]
    fn test_supertrend_rejects_mismatched_lengths() {
        let (highs, lows, closes) = flat_series(20);
        let result = supertrend(&highs, &lows[..19], &closes, 10, 3.0);
        assert!(matches!(result, Err(IndicatorError::InvalidPeriod(_))));
    }

    #[test]
    fn test_supertrend_rejects_insufficient_data() {
        let (highs, lows, closes) = flat_series(5);
        let result = supertrend(&highs, &lows, &closes, 10, 3.0);
        assert!(matches!(
            result,
            Err(IndicatorError::InsufficientData { need: 10, got: 5 })
        ));
    }
}