Skip to main content

finance_query/indicators/
supertrend.rs

1//! SuperTrend indicator.
2
3use super::{IndicatorError, Result, atr::atr};
4use serde::{Deserialize, Serialize};
5
6/// Result of SuperTrend calculation
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub struct SuperTrendResult {
9    /// SuperTrend line
10    pub value: Vec<Option<f64>>,
11    /// Trend direction (true = up, false = down)
12    pub is_uptrend: Vec<Option<bool>>,
13}
14
15/// Calculate SuperTrend.
16///
17/// Trend-following indicator based on ATR.
18///
19/// # Arguments
20///
21/// * `highs` - High prices
22/// * `lows` - Low prices
23/// * `closes` - Close prices
24/// * `period` - ATR period
25/// * `multiplier` - ATR multiplier
26///
27/// # Example
28///
29/// ```
30/// use finance_query::indicators::supertrend;
31///
32/// let highs = vec![10.0; 20];
33/// let lows = vec![8.0; 20];
34/// let closes = vec![9.0; 20];
35/// let result = supertrend(&highs, &lows, &closes, 10, 3.0).unwrap();
36/// ```
37pub fn supertrend(
38    highs: &[f64],
39    lows: &[f64],
40    closes: &[f64],
41    period: usize,
42    multiplier: f64,
43) -> Result<SuperTrendResult> {
44    if period == 0 {
45        return Err(IndicatorError::InvalidPeriod(
46            "Period must be greater than 0".to_string(),
47        ));
48    }
49    let len = highs.len();
50    if lows.len() != len || closes.len() != len {
51        return Err(IndicatorError::InvalidPeriod(
52            "Data lengths must match".to_string(),
53        ));
54    }
55    if len < period {
56        return Err(IndicatorError::InsufficientData {
57            need: period,
58            got: len,
59        });
60    }
61
62    let atr_values = atr(highs, lows, closes, period)?;
63
64    let mut supertrend = vec![None; len];
65    let mut is_uptrend = vec![None; len];
66
67    let start_idx = period - 1;
68
69    let mut prev_final_upper = 0.0;
70    let mut prev_final_lower = 0.0;
71    let mut prev_trend = true;
72
73    for i in start_idx..len {
74        if let Some(atr_val) = atr_values[i] {
75            let hl2 = (highs[i] + lows[i]) / 2.0;
76            let basic_upper = hl2 + (multiplier * atr_val);
77            let basic_lower = hl2 - (multiplier * atr_val);
78
79            let current_close = closes[i];
80            let prev_close = if i > 0 { closes[i - 1] } else { current_close };
81
82            let final_upper = if i == start_idx
83                || basic_upper < prev_final_upper
84                || prev_close > prev_final_upper
85            {
86                basic_upper
87            } else {
88                prev_final_upper
89            };
90
91            let final_lower = if i == start_idx
92                || basic_lower > prev_final_lower
93                || prev_close < prev_final_lower
94            {
95                basic_lower
96            } else {
97                prev_final_lower
98            };
99
100            let trend = if i == start_idx {
101                true
102            } else if prev_trend && current_close <= final_lower {
103                false
104            } else if !prev_trend && current_close >= final_upper {
105                true
106            } else {
107                prev_trend
108            };
109
110            let st_val = if trend { final_lower } else { final_upper };
111
112            supertrend[i] = Some(st_val);
113            is_uptrend[i] = Some(trend);
114
115            prev_final_upper = final_upper;
116            prev_final_lower = final_lower;
117            prev_trend = trend;
118        }
119    }
120
121    Ok(SuperTrendResult {
122        value: supertrend,
123        is_uptrend,
124    })
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Constant bars: high 10, low 8, close 9.
    fn flat_series(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        (vec![10.0; len], vec![8.0; len], vec![9.0; len])
    }

    #[test]
    fn test_supertrend() {
        let (highs, lows, closes) = flat_series(20);
        let result = supertrend(&highs, &lows, &closes, 10, 3.0).unwrap();

        assert_eq!(result.value.len(), 20);
        assert_eq!(result.is_uptrend.len(), 20);
        assert!(result.value[8].is_none());
        assert!(result.is_uptrend[8].is_none());
        assert!(result.value[9].is_some());
    }

    #[test]
    fn test_supertrend_flat_series_stays_on_lower_band() {
        // Every bar's true range is 2.0, so any average of it is 2.0. The lower
        // band is therefore hl2 - 3 * 2 = 9 - 6 = 3.0, and the close (9.0)
        // never crosses it.
        let (highs, lows, closes) = flat_series(20);
        let result = supertrend(&highs, &lows, &closes, 10, 3.0).unwrap();

        for i in 9..20 {
            assert_eq!(result.is_uptrend[i], Some(true));
            let v = result.value[i].unwrap();
            assert!((v - 3.0).abs() < 1e-9, "bar {}: {}", i, v);
        }
    }

    #[test]
    fn test_supertrend_rejects_zero_period() {
        let (highs, lows, closes) = flat_series(5);
        assert!(matches!(
            supertrend(&highs, &lows, &closes, 0, 3.0),
            Err(IndicatorError::InvalidPeriod(_))
        ));
    }

    #[test]
    fn test_supertrend_rejects_mismatched_lengths() {
        let (highs, _, closes) = flat_series(20);
        let lows = vec![8.0; 19];
        assert!(matches!(
            supertrend(&highs, &lows, &closes, 10, 3.0),
            Err(IndicatorError::InvalidPeriod(_))
        ));
    }

    #[test]
    fn test_supertrend_insufficient_data() {
        let (highs, lows, closes) = flat_series(5);
        assert!(matches!(
            supertrend(&highs, &lows, &closes, 10, 3.0),
            Err(IndicatorError::InsufficientData { need: 10, got: 5 })
        ));
    }
}