use quant_primitives::Candle;
use rust_decimal::Decimal;
use crate::error::IndicatorError;
use crate::indicator::Indicator;
use crate::series::Series;
use crate::true_range;
/// Supertrend trend-following indicator configuration.
///
/// Bands are placed at the candle midpoint `(high + low) / 2` plus/minus
/// `multiplier` times a Wilder-smoothed ATR over `atr_period` candles. The lower
/// band never moves down and the upper band never moves up unless the previous
/// close broke through it. The computed series is the trend direction (`1` or `-1`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Supertrend {
    /// Look-back length of the smoothed ATR; validated to be `> 0` by `new`.
    atr_period: usize,
    /// Band half-width expressed in ATR units; validated to be `> 0` by `new`.
    multiplier: Decimal,
    /// Cached display name of the form `Supertrend(<atr_period>,<multiplier>)`.
    name: String,
}
impl Supertrend {
    /// Builds a Supertrend indicator from an ATR look-back and a band multiplier.
    ///
    /// # Errors
    ///
    /// Returns [`IndicatorError::InvalidParameter`] when `atr_period` is `0` or
    /// when `multiplier` is not strictly positive. The period is checked first.
    pub fn new(atr_period: usize, multiplier: Decimal) -> Result<Self, IndicatorError> {
        let rejection = match (atr_period, multiplier > Decimal::ZERO) {
            (0, _) => Some("Supertrend ATR period must be > 0"),
            (_, false) => Some("Supertrend multiplier must be > 0"),
            _ => None,
        };
        if let Some(reason) = rejection {
            return Err(IndicatorError::InvalidParameter {
                message: reason.to_string(),
            });
        }

        let name = format!("Supertrend({},{})", atr_period, multiplier);
        Ok(Self {
            atr_period,
            multiplier,
            name,
        })
    }

    /// ATR look-back length this indicator was configured with.
    #[must_use]
    pub fn atr_period(&self) -> usize {
        self.atr_period
    }

    /// Band multiplier (in ATR units) this indicator was configured with.
    #[must_use]
    pub fn multiplier(&self) -> Decimal {
        self.multiplier
    }
}
/// Wilder-smoothed Average True Range, one value per input candle.
///
/// The true range of the first candle is `high - low` (it has no prior close);
/// every later candle uses [`true_range`] against the previous candle's close.
/// The first `period` entries all hold the simple mean of the first `period`
/// true ranges. From index `period` onward, each entry is
/// `(prev_atr * (period - 1) + tr) / period`.
///
/// # Panics
///
/// Panics if `period == 0` or `candles.len() < period`. Callers must validate
/// both beforehand (`Supertrend::new` rejects a zero period and `compute`
/// checks the input length).
fn smoothed_atr(candles: &[Candle], period: usize) -> Vec<Decimal> {
    // Make the precondition explicit instead of failing on an opaque slice index
    // or a Decimal division by zero further down.
    assert!(
        period > 0 && period <= candles.len(),
        "smoothed_atr: need 0 < period <= candles.len(), got period={} len={}",
        period,
        candles.len()
    );

    let period_dec = Decimal::from(period);
    // Loop-invariant weight given to the previous ATR value.
    let carry = period_dec - Decimal::ONE;

    let true_ranges: Vec<Decimal> = std::iter::once(candles[0].high() - candles[0].low())
        .chain(
            candles
                .windows(2)
                .map(|pair| true_range(&pair[1], pair[0].close())),
        )
        .collect();

    let seed = true_ranges[..period].iter().sum::<Decimal>() / period_dec;

    let mut atr_values = Vec::with_capacity(candles.len());
    // The seed stands in for every candle inside the initial window.
    atr_values.resize(period, seed);

    let mut atr = seed;
    for &tr in &true_ranges[period..] {
        atr = (atr * carry + tr) / period_dec;
        atr_values.push(atr);
    }
    atr_values
}
/// Ratchets freshly computed basic bands against the previous final bands.
///
/// The previous lower band is kept only when the new basic lower band is not
/// higher and the previous close did not fall below it. Symmetrically, the
/// previous upper band is kept only when the new basic upper band is not lower
/// and the previous close did not rise above it. Otherwise the basic band wins.
///
/// Returns `(final_lower, final_upper)`.
fn ratchet_bands(
    basic_lower: Decimal,
    basic_upper: Decimal,
    prev_lower: Decimal,
    prev_upper: Decimal,
    prev_close: Decimal,
) -> (Decimal, Decimal) {
    let keep_prev_lower = basic_lower <= prev_lower && prev_close >= prev_lower;
    let keep_prev_upper = basic_upper >= prev_upper && prev_close <= prev_upper;

    let final_lower = if keep_prev_lower { prev_lower } else { basic_lower };
    let final_upper = if keep_prev_upper { prev_upper } else { basic_upper };
    (final_lower, final_upper)
}
/// Computes the next trend direction from the previous one and the current bands.
///
/// While `direction` is `1`, the trend stays `1` unless `close` falls strictly
/// below `lower`. For any other `direction` value, the result is `1` only when
/// `close` rises strictly above `upper`; otherwise it is `-1`.
fn flip_direction(direction: Decimal, close: Decimal, lower: Decimal, upper: Decimal) -> Decimal {
    let bullish = if direction == Decimal::ONE {
        close >= lower
    } else {
        close > upper
    };

    if bullish {
        Decimal::ONE
    } else {
        -Decimal::ONE
    }
}
impl Indicator for Supertrend {
    fn name(&self) -> &str {
        &self.name
    }

    /// The first output uses the first smoothed ATR value, at index
    /// `atr_period`, so `atr_period + 1` candles are required.
    fn warmup_period(&self) -> usize {
        self.atr_period + 1
    }

    /// Computes the Supertrend direction series.
    ///
    /// Emits one `(timestamp, direction)` point per candle, starting at index
    /// `atr_period`, where `direction` is `1` (up-trend) or `-1` (down-trend).
    ///
    /// # Errors
    ///
    /// Returns [`IndicatorError::InsufficientData`] when fewer than
    /// `atr_period + 1` candles are supplied.
    fn compute(&self, candles: &[Candle]) -> Result<Series, IndicatorError> {
        let required = self.warmup_period();
        let actual = candles.len();
        if actual < required {
            return Err(IndicatorError::InsufficientData { required, actual });
        }

        let two = Decimal::from(2);
        let atr = smoothed_atr(candles, self.atr_period);

        // Midpoint plus basic (un-ratcheted) lower/upper bands for candle `idx`.
        let basic_bands = |idx: usize| {
            let hl2 = (candles[idx].high() + candles[idx].low()) / two;
            let width = self.multiplier * atr[idx];
            (hl2, hl2 - width, hl2 + width)
        };

        let first = self.atr_period;
        // The first point has no earlier bands to ratchet against, so its basic
        // bands become the final bands. Its direction is seeded from close vs. midpoint.
        let (seed_mid, mut lower, mut upper) = basic_bands(first);
        let mut direction = if candles[first].close() > seed_mid {
            Decimal::ONE
        } else {
            -Decimal::ONE
        };

        let mut points = Vec::with_capacity(actual - first);
        points.push((candles[first].timestamp(), direction));

        for (idx, candle) in candles.iter().enumerate().skip(first + 1) {
            let (_, basic_lower, basic_upper) = basic_bands(idx);
            let prev_close = candles[idx - 1].close();
            let (next_lower, next_upper) =
                ratchet_bands(basic_lower, basic_upper, lower, upper, prev_close);
            direction = flip_direction(direction, candle.close(), next_lower, next_upper);
            lower = next_lower;
            upper = next_upper;
            points.push((candle.timestamp(), direction));
        }

        Ok(Series::new(points))
    }
}
#[cfg(test)]
#[path = "supertrend_tests.rs"]
mod tests;