quant_indicators/
supertrend.rs1use quant_primitives::Candle;
32use rust_decimal::Decimal;
33
34use crate::error::IndicatorError;
35use crate::indicator::Indicator;
36use crate::series::Series;
37use crate::true_range;
38
39#[derive(Debug, Clone)]
60pub struct Supertrend {
61 atr_period: usize,
62 multiplier: Decimal,
63 name: String,
64}
65
66impl Supertrend {
67 pub fn new(atr_period: usize, multiplier: Decimal) -> Result<Self, IndicatorError> {
78 if atr_period == 0 {
79 return Err(IndicatorError::InvalidParameter {
80 message: "Supertrend ATR period must be > 0".to_string(),
81 });
82 }
83 if multiplier <= Decimal::ZERO {
84 return Err(IndicatorError::InvalidParameter {
85 message: "Supertrend multiplier must be > 0".to_string(),
86 });
87 }
88 Ok(Self {
89 atr_period,
90 multiplier,
91 name: format!("Supertrend({},{})", atr_period, multiplier),
92 })
93 }
94
95 #[must_use]
97 pub fn atr_period(&self) -> usize {
98 self.atr_period
99 }
100
101 #[must_use]
102 pub fn multiplier(&self) -> Decimal {
103 self.multiplier
104 }
105}
106
107fn smoothed_atr(candles: &[Candle], period: usize) -> Vec<Decimal> {
109 let period_dec = Decimal::from(period as u64);
110 let mut true_ranges = Vec::with_capacity(candles.len());
111 true_ranges.push(candles[0].high() - candles[0].low());
112 for i in 1..candles.len() {
113 true_ranges.push(true_range(&candles[i], candles[i - 1].close()));
114 }
115
116 let initial_sum: Decimal = true_ranges[..period].iter().sum();
117 let mut atr = initial_sum / period_dec;
118
119 let mut atr_values = Vec::with_capacity(candles.len());
120 for _ in 0..period {
121 atr_values.push(atr);
122 }
123 for tr in true_ranges.iter().skip(period) {
124 atr = (atr * (period_dec - Decimal::ONE) + *tr) / period_dec;
125 atr_values.push(atr);
126 }
127 atr_values
128}
129
130fn ratchet_bands(
132 basic_lower: Decimal,
133 basic_upper: Decimal,
134 prev_lower: Decimal,
135 prev_upper: Decimal,
136 prev_close: Decimal,
137) -> (Decimal, Decimal) {
138 let lower = if basic_lower > prev_lower || prev_close < prev_lower {
139 basic_lower
140 } else {
141 prev_lower
142 };
143 let upper = if basic_upper < prev_upper || prev_close > prev_upper {
144 basic_upper
145 } else {
146 prev_upper
147 };
148 (lower, upper)
149}
150
151fn flip_direction(direction: Decimal, close: Decimal, lower: Decimal, upper: Decimal) -> Decimal {
153 if direction == Decimal::ONE {
154 if close < lower {
155 -Decimal::ONE
156 } else {
157 Decimal::ONE
158 }
159 } else if close > upper {
160 Decimal::ONE
161 } else {
162 -Decimal::ONE
163 }
164}
165
166impl Indicator for Supertrend {
167 fn name(&self) -> &str {
168 &self.name
169 }
170
171 fn warmup_period(&self) -> usize {
172 self.atr_period + 1
173 }
174
175 fn compute(&self, candles: &[Candle]) -> Result<Series, IndicatorError> {
176 let required = self.atr_period + 1;
177 if candles.len() < required {
178 return Err(IndicatorError::InsufficientData {
179 required,
180 actual: candles.len(),
181 });
182 }
183
184 let two = Decimal::from(2);
185 let atr_values = smoothed_atr(candles, self.atr_period);
186
187 let start = self.atr_period;
188 let mut values = Vec::with_capacity(candles.len() - start);
189
190 let mid = (candles[start].high() + candles[start].low()) / two;
191 let mut final_upper = mid + self.multiplier * atr_values[start];
192 let mut final_lower = mid - self.multiplier * atr_values[start];
193 let mut direction: Decimal = if candles[start].close() > mid {
194 Decimal::ONE
195 } else {
196 -Decimal::ONE
197 };
198 values.push((candles[start].timestamp(), direction));
199
200 for i in (start + 1)..candles.len() {
201 let mid_i = (candles[i].high() + candles[i].low()) / two;
202 let basic_upper = mid_i + self.multiplier * atr_values[i];
203 let basic_lower = mid_i - self.multiplier * atr_values[i];
204
205 let (new_lower, new_upper) = ratchet_bands(
206 basic_lower,
207 basic_upper,
208 final_lower,
209 final_upper,
210 candles[i - 1].close(),
211 );
212 direction = flip_direction(direction, candles[i].close(), new_lower, new_upper);
213 final_upper = new_upper;
214 final_lower = new_lower;
215
216 values.push((candles[i].timestamp(), direction));
217 }
218
219 Ok(Series::new(values))
220 }
221}
222
223#[cfg(test)]
224#[path = "supertrend_tests.rs"]
225mod tests;