use crate::indicator_error::IndicatorError;
use tracing::instrument;
/// Direction assigned to a bar by [`calculate_supertrend`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrendType {
    /// Uptrend: kept until a price falls to or below the lower band.
    Up,
    /// Downtrend: kept until a price rises to or above the upper band.
    Down,
}
#[derive(Debug)]
pub struct SupertrendData {
pub trend: TrendType,
pub upper_band: f64,
pub lower_band: f64,
}
/// Parameters for [`calculate_supertrend`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SettingSupertrend {
    /// Multiplier applied to the ATR to offset both bands from the median price.
    ///
    /// Values `<= 0.0` are rejected by [`calculate_supertrend`] with
    /// [`IndicatorError::ImproperSetting`].
    pub factor: f64,
}
/// Computes the Supertrend indicator bar by bar.
///
/// For every bar the basic bands are `median ± factor * atr`.
///
/// * The upper band is lowered to the basic upper band whenever that is lower.
///   Otherwise it is held, unless the previous price was above it, in which case
///   it resets to the basic upper band.
/// * The lower band is raised to the basic lower band whenever that is higher.
///   Otherwise it is held, unless the previous price was below it, in which case
///   it resets to the basic lower band.
/// * The trend flips `Up -> Down` when the price is at or below the lower band.
///   It flips `Down -> Up` when the price is at or above the upper band.
/// * The first bar is `Up` if its price is strictly above its lower band, and
///   `Down` otherwise.
///
/// A NaN band is replaced by the next bar's basic band instead of being carried
/// forward forever. A NaN band arises, for example, when `atr_data` begins with
/// NaN values; whether the upstream ATR produces them is not visible here.
///
/// # Arguments
/// * `candle_data` - per-bar price compared against the bands (the tests in this
///   module pass closing prices).
/// * `median_price_data` - per-bar centre line of the bands.
/// * `atr_data` - per-bar band half-width before scaling by `setting.factor`.
/// * `setting` - supplies the ATR multiplier.
///
/// # Returns
/// One [`SupertrendData`] per input bar, in input order.
///
/// # Errors
/// * [`IndicatorError::EmptyData`] if any input slice is empty.
/// * [`IndicatorError::DifferentDataLength`] if the slices differ in length.
/// * [`IndicatorError::ImproperSetting`] if `setting.factor` is not a finite
///   number greater than zero.
#[instrument(level = "trace", skip_all, ret)]
pub fn calculate_supertrend(
    candle_data: &[f64],
    median_price_data: &[f64],
    atr_data: &[f64],
    setting: &SettingSupertrend,
) -> Result<Vec<SupertrendData>, IndicatorError> {
    if candle_data.is_empty() || median_price_data.is_empty() || atr_data.is_empty() {
        return Err(IndicatorError::EmptyData);
    }
    if candle_data.len() != median_price_data.len() || median_price_data.len() != atr_data.len() {
        return Err(IndicatorError::DifferentDataLength);
    }
    let factor = setting.factor;
    // `factor <= 0.0` alone lets NaN through (every NaN comparison is false), which
    // would turn every band into NaN; an infinite factor is equally meaningless.
    if !factor.is_finite() || factor <= 0.0 {
        return Err(IndicatorError::ImproperSetting);
    }

    let basic_bands = |median: f64, atr: f64| (median + factor * atr, median - factor * atr);

    // State of the previous bar, carried across loop iterations.
    let (mut upper_band, mut lower_band) = basic_bands(median_price_data[0], atr_data[0]);
    let mut trend = if candle_data[0] > lower_band {
        TrendType::Up
    } else {
        TrendType::Down
    };
    let mut prev_close = candle_data[0];

    let mut supertrend = Vec::with_capacity(candle_data.len());
    supertrend.push(SupertrendData {
        trend,
        upper_band,
        lower_band,
    });

    let rest = candle_data
        .iter()
        .zip(median_price_data)
        .zip(atr_data)
        .skip(1);
    for ((&close, &median), &atr) in rest {
        let (basic_upper, basic_lower) = basic_bands(median, atr);
        // A NaN previous band is treated as "no band yet" so it cannot stick forever.
        upper_band = if upper_band.is_nan() || basic_upper < upper_band || prev_close > upper_band {
            basic_upper
        } else {
            upper_band
        };
        lower_band = if lower_band.is_nan() || basic_lower > lower_band || prev_close < lower_band {
            basic_lower
        } else {
            lower_band
        };
        // Flip only when price crosses the band on the opposite side of the current trend.
        trend = match trend {
            TrendType::Up if close <= lower_band => TrendType::Down,
            TrendType::Down if close >= upper_band => TrendType::Up,
            TrendType::Up => TrendType::Up,
            TrendType::Down => TrendType::Down,
        };
        supertrend.push(SupertrendData {
            trend,
            upper_band,
            lower_band,
        });
        prev_close = close;
    }
    Ok(supertrend)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Six bars with a constant ATR of 1.0. With `factor = 2.0` the trend runs
    /// Up, Up, Up, Down, Down, Up.
    fn make_custom_data() -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let close = vec![10.0, 11.0, 12.0, 9.0, 8.0, 13.0];
        let median = vec![10.0, 10.5, 11.5, 9.5, 8.5, 12.5];
        let atr = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        (close, median, atr)
    }

    /// Asserts two floats are equal within a tight tolerance.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_calculate_supertrend_basic() {
        let (close, median, atr) = make_custom_data();
        let setting = SettingSupertrend { factor: 2.0 };
        let result = calculate_supertrend(&close, &median, &atr, &setting).unwrap();
        assert_eq!(result.len(), 6);
        assert_eq!(result[0].trend, TrendType::Up);
        assert_close(result[0].upper_band, 12.0);
        assert_close(result[0].lower_band, 8.0);
    }

    #[test]
    fn test_calculate_supertrend_full_sequence() {
        let (close, median, atr) = make_custom_data();
        let setting = SettingSupertrend { factor: 2.0 };
        let result = calculate_supertrend(&close, &median, &atr, &setting).unwrap();
        let trends: Vec<TrendType> = result.iter().map(|d| d.trend).collect();
        assert_eq!(
            trends,
            vec![
                TrendType::Up,
                TrendType::Up,
                TrendType::Up,
                TrendType::Down,
                TrendType::Down,
                TrendType::Up,
            ]
        );
        let expected_upper = [12.0, 12.0, 12.0, 11.5, 10.5, 10.5];
        let expected_lower = [8.0, 8.5, 9.5, 9.5, 6.5, 10.5];
        for (bar, (&up, &low)) in result
            .iter()
            .zip(expected_upper.iter().zip(expected_lower.iter()))
        {
            assert_close(bar.upper_band, up);
            assert_close(bar.lower_band, low);
        }
    }

    #[test]
    fn test_calculate_supertrend_trend_flip_up_to_down() {
        let (close, median, atr) = make_custom_data();
        let setting = SettingSupertrend { factor: 2.0 };
        let result = calculate_supertrend(&close, &median, &atr, &setting).unwrap();
        // The bar before the flip must be Up for this to be a genuine flip.
        assert_eq!(result[2].trend, TrendType::Up);
        assert_eq!(result[3].trend, TrendType::Down);
        assert_eq!(result[4].trend, TrendType::Down);
    }

    #[test]
    fn test_calculate_supertrend_trend_flip_down_to_up() {
        let (close, median, atr) = make_custom_data();
        let setting = SettingSupertrend { factor: 2.0 };
        let result = calculate_supertrend(&close, &median, &atr, &setting).unwrap();
        assert_eq!(result[4].trend, TrendType::Down);
        assert_eq!(result[5].trend, TrendType::Up);
        // The price of 13.0 crossed the held upper band of 10.5.
        assert_close(result[5].upper_band, 10.5);
    }

    #[test]
    fn test_calculate_supertrend_band_continuity() {
        let close = vec![10.0, 10.0, 10.0, 10.0];
        let median = vec![10.0, 10.0, 5.0, 5.0];
        let atr = vec![1.0, 1.0, 1.0, 1.0];
        let setting = SettingSupertrend { factor: 2.0 };
        let result = calculate_supertrend(&close, &median, &atr, &setting).unwrap();
        assert_eq!(result.len(), 4);
        assert_eq!(result[0].trend, TrendType::Up);
        // Unchanged inputs hold both bands.
        assert_close(result[1].upper_band, 12.0);
        assert_close(result[1].lower_band, 8.0);
        // Upper band follows the lower basic band down; lower band is held.
        assert_close(result[2].upper_band, 7.0);
        assert_close(result[2].lower_band, 8.0);
    }

    #[test]
    fn test_calculate_supertrend_single_bar() {
        let setting = SettingSupertrend { factor: 2.0 };
        let up = calculate_supertrend(&[9.0], &[10.0], &[1.0], &setting).unwrap();
        assert_eq!(up.len(), 1);
        assert_eq!(up[0].trend, TrendType::Up);
        assert_close(up[0].upper_band, 12.0);
        assert_close(up[0].lower_band, 8.0);
        let down = calculate_supertrend(&[7.0], &[10.0], &[1.0], &setting).unwrap();
        assert_eq!(down[0].trend, TrendType::Down);
    }

    #[test]
    fn test_calculate_supertrend_empty() {
        let setting = SettingSupertrend { factor: 2.0 };
        assert!(matches!(
            calculate_supertrend(&[], &[1.0], &[1.0], &setting).unwrap_err(),
            IndicatorError::EmptyData
        ));
        assert!(matches!(
            calculate_supertrend(&[1.0], &[], &[1.0], &setting).unwrap_err(),
            IndicatorError::EmptyData
        ));
        assert!(matches!(
            calculate_supertrend(&[1.0], &[1.0], &[], &setting).unwrap_err(),
            IndicatorError::EmptyData
        ));
    }

    #[test]
    fn test_calculate_supertrend_length_mismatch() {
        let setting = SettingSupertrend { factor: 2.0 };
        assert!(matches!(
            calculate_supertrend(&[1.0, 2.0], &[1.0], &[1.0], &setting).unwrap_err(),
            IndicatorError::DifferentDataLength
        ));
        assert!(matches!(
            calculate_supertrend(&[1.0, 2.0], &[1.0, 2.0], &[1.0], &setting).unwrap_err(),
            IndicatorError::DifferentDataLength
        ));
    }

    #[test]
    fn test_calculate_supertrend_zero_factor() {
        assert!(matches!(
            calculate_supertrend(&[1.0], &[1.0], &[1.0], &SettingSupertrend { factor: 0.0 })
                .unwrap_err(),
            IndicatorError::ImproperSetting
        ));
    }

    #[test]
    fn test_calculate_supertrend_negative_factor() {
        assert!(matches!(
            calculate_supertrend(&[1.0], &[1.0], &[1.0], &SettingSupertrend { factor: -1.0 })
                .unwrap_err(),
            IndicatorError::ImproperSetting
        ));
    }
}