use chrono::{TimeZone, Utc};
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
use crate::Indicator;
use super::Supertrend;
/// Returns the UTC instant `i` hours after Unix second 1_700_000_000.
fn ts(i: i64) -> chrono::DateTime<Utc> {
    let secs = 1_700_000_000 + 3600 * i;
    Utc.timestamp_opt(secs, 0).unwrap()
}
/// Builds a test candle stamped `ts(i)` with the given OHLC prices and a
/// fixed volume of 1000.
///
/// Panics if `quant_primitives::Candle::new` rejects the values.
fn candle(
    i: i64,
    open: Decimal,
    high: Decimal,
    low: Decimal,
    close: Decimal,
) -> quant_primitives::Candle {
    let volume = dec!(1000);
    let stamp = ts(i);
    quant_primitives::Candle::new(open, high, low, close, volume, stamp).unwrap()
}
/// A zero ATR period is rejected by the constructor with a descriptive error.
#[test]
fn new_rejects_zero_period() {
    let message = Supertrend::new(0, dec!(2)).unwrap_err().to_string();
    assert!(message.contains("period must be > 0"));
}
/// A zero band multiplier is rejected by the constructor.
#[test]
fn new_rejects_zero_multiplier() {
    let message = Supertrend::new(14, dec!(0)).unwrap_err().to_string();
    assert!(message.contains("multiplier must be > 0"));
}
/// A negative band multiplier is rejected with the same error as zero.
#[test]
fn new_rejects_negative_multiplier() {
    let message = Supertrend::new(14, dec!(-1)).unwrap_err().to_string();
    assert!(message.contains("multiplier must be > 0"));
}
/// Ten candles are not enough for a 14-period indicator; `compute` must error.
#[test]
fn insufficient_data() {
    let indicator = Supertrend::new(14, dec!(2)).unwrap();
    let flat: Vec<_> = (0..10)
        .map(|i| candle(i, dec!(100), dec!(110), dec!(90), dec!(100)))
        .collect();
    let message = indicator.compute(&flat).unwrap_err().to_string();
    assert!(message.contains("insufficient data"));
}
/// The indicator name embeds both the ATR period and the multiplier.
#[test]
fn name_includes_params() {
    let indicator = Supertrend::new(10, dec!(2.5)).unwrap();
    let expected = "Supertrend(10,2.5)";
    assert_eq!(indicator.name(), expected);
}
/// With an ATR period of 10, the reported warmup is 11 bars.
#[test]
fn warmup_period_is_atr_period_plus_one() {
    let indicator = Supertrend::new(10, dec!(2)).unwrap();
    let warmup = indicator.warmup_period();
    assert_eq!(warmup, 11);
}
/// A series rising 2 points per bar, each bar closing 3 above its open,
/// must end in an uptrend (`+1`).
#[test]
fn uptrend_detected_in_rising_series() {
    let indicator = Supertrend::new(10, dec!(2)).unwrap();
    let rising: Vec<_> = (0..30)
        .map(|i| {
            let base = Decimal::from(100 + i * 2);
            candle(i, base, base + dec!(5), base - dec!(3), base + dec!(3))
        })
        .collect();
    let series = indicator.compute(&rising).unwrap();
    let last_value = series.values().last().map(|(_, v)| *v).unwrap();
    assert_eq!(
        last_value,
        dec!(1),
        "Expected uptrend (+1) for rising series"
    );
}
/// A series falling 2 points per bar, each bar closing 3 below its open,
/// must end in a downtrend (`-1`).
#[test]
fn downtrend_detected_in_falling_series() {
    let indicator = Supertrend::new(10, dec!(2)).unwrap();
    let falling: Vec<_> = (0..30)
        .map(|i| {
            let base = Decimal::from(200 - i * 2);
            candle(i, base, base + dec!(3), base - dec!(5), base - dec!(3))
        })
        .collect();
    let series = indicator.compute(&falling).unwrap();
    let last_value = series.values().last().map(|(_, v)| *v).unwrap();
    assert_eq!(
        last_value,
        dec!(-1),
        "Expected downtrend (-1) for falling series"
    );
}
/// Twenty bars of steady gains followed by twenty bars of a sharp sell-off:
/// the direction must start at `+1`, end at `-1`, and flip at least once.
#[test]
fn trend_flip_on_decisive_move() {
    // Bars 0..20: rising 2 points per bar, closing near the high.
    let rally = (0..20).map(|i| {
        let base = Decimal::from(100 + i * 2);
        candle(i, base, base + dec!(3), base - dec!(2), base + dec!(2))
    });
    // Bars 20..40: falling 4 points per bar, closing near the low.
    let selloff = (0..20).map(|k| {
        let base = Decimal::from(140 - k * 4);
        candle(20 + k, base, base + dec!(2), base - dec!(5), base - dec!(4))
    });
    let candles: Vec<_> = rally.chain(selloff).collect();

    let indicator = Supertrend::new(10, dec!(2)).unwrap();
    let series = indicator.compute(&candles).unwrap();
    let values = series.values();

    let early_direction = values[5].1;
    assert_eq!(early_direction, dec!(1), "Expected uptrend early");

    let late_direction = values.last().unwrap().1;
    assert_eq!(
        late_direction,
        dec!(-1),
        "Expected downtrend after sharp drop"
    );

    let flipped = values.windows(2).any(|pair| pair[0].1 != pair[1].1);
    assert!(flipped, "Expected at least one direction flip");
}
/// An uptrend, a shallow 5-bar pullback, then a resumed uptrend: with a wide
/// (3x) multiplier the final direction must still be `+1`.
#[test]
fn holds_through_minor_pullback() {
    // Shared shape for the up-legs: close 3 above open, range [-2, +4].
    let rising_bar = |idx: i64, base: Decimal| {
        candle(idx, base, base + dec!(4), base - dec!(2), base + dec!(3))
    };

    // Bars 0..15: first up-leg, +3 per bar.
    let mut candles: Vec<_> = (0..15)
        .map(|i| rising_bar(i, Decimal::from(100 + i * 3)))
        .collect();
    // Bars 15..20: minor pullback, -1 per bar, closing 1 below open.
    candles.extend((0..5).map(|i| {
        let base = Decimal::from(145 - i);
        candle(15 + i, base, base + dec!(3), base - dec!(2), base - dec!(1))
    }));
    // Bars 20..30: second up-leg, +3 per bar.
    candles.extend((0..10).map(|i| rising_bar(20 + i, Decimal::from(140 + i * 3))));

    let indicator = Supertrend::new(10, dec!(3)).unwrap();
    let series = indicator.compute(&candles).unwrap();
    let last_direction = series.values().last().unwrap().1;
    assert_eq!(
        last_direction,
        dec!(1),
        "Wide-multiplier Supertrend should hold through minor pullback"
    );
}
/// On a noisy, slowly rising series, a tight (1x) multiplier must flip
/// direction at least as often as a wide (3x) one.
#[test]
fn multiplier_affects_sensitivity() {
    let candles: Vec<_> = (0..40)
        .map(|i| {
            let base = Decimal::from(100 + i);
            // Every third bar spikes up; the rest close below the open.
            let noise = if i % 3 == 0 { dec!(8) } else { dec!(-5) };
            candle(i, base, base + dec!(6) + noise, base - dec!(6), base + noise)
        })
        .collect();

    // Counts adjacent output pairs whose direction differs.
    let flips_for = |multiplier: Decimal| -> usize {
        let indicator = Supertrend::new(10, multiplier).unwrap();
        let series: crate::Series = indicator.compute(&candles).unwrap();
        series
            .values()
            .windows(2)
            .filter(|pair| pair[0].1 != pair[1].1)
            .count()
    };

    let flips_tight = flips_for(dec!(1));
    let flips_wide = flips_for(dec!(3));
    assert!(
        flips_tight >= flips_wide,
        "Tight multiplier ({} flips) should have >= flips than wide ({} flips)",
        flips_tight,
        flips_wide
    );
}
/// Thirty candles with an ATR period of 10 yield exactly 20 output points.
#[test]
fn output_length_is_candles_minus_atr_period() {
    let indicator = Supertrend::new(10, dec!(2)).unwrap();
    let flat: Vec<_> = (0..30)
        .map(|i| candle(i, dec!(100), dec!(110), dec!(90), dec!(100)))
        .collect();
    let series = indicator.compute(&flat).unwrap();
    assert_eq!(
        series.len(),
        30 - 10,
        "Output length = candles - atr_period"
    );
}
/// Every emitted value is a direction flag: exactly `+1` or `-1`.
#[test]
fn output_values_are_only_plus_or_minus_one() {
    let indicator = Supertrend::new(10, dec!(2)).unwrap();
    let candles: Vec<_> = (0..30)
        .map(|i| {
            let base = Decimal::from(100 + i);
            candle(i, base, base + dec!(5), base - dec!(3), base + dec!(2))
        })
        .collect();
    let series = indicator.compute(&candles).unwrap();
    for (_, direction) in series.values() {
        let is_unit = *direction == dec!(1) || *direction == dec!(-1);
        assert!(is_unit, "Direction must be +1 or -1, got {}", direction);
    }
}
/// Fixed 7-bar reference with ATR period 3 and multiplier 2. It pins the
/// direction of each of the 4 output bars, including a crash on the last bar.
#[test]
fn known_good_reference_series() {
    // (open, high, low, close) for bars 0..=6.
    let ohlc = [
        (dec!(100), dec!(105), dec!(95), dec!(102)),
        (dec!(102), dec!(108), dec!(98), dec!(106)),
        (dec!(106), dec!(112), dec!(100), dec!(104)),
        (dec!(104), dec!(110), dec!(99), dec!(108)),
        (dec!(108), dec!(115), dec!(103), dec!(113)),
        (dec!(113), dec!(118), dec!(107), dec!(109)),
        (dec!(109), dec!(112), dec!(80), dec!(85)),
    ];
    let candles: Vec<_> = ohlc
        .iter()
        .zip(0..)
        .map(|(&(o, h, l, c), idx)| candle(idx, o, h, l, c))
        .collect();

    let indicator = Supertrend::new(3, dec!(2)).unwrap();
    let series = indicator.compute(&candles).unwrap();
    let directions: Vec<Decimal> = series.values().iter().map(|&(_, v)| v).collect();

    assert_eq!(directions.len(), 4, "Expected 4 output values");
    assert_eq!(directions[0], dec!(1), "Bar 3: uptrend (close > mid)");
    assert_eq!(directions[1], dec!(1), "Bar 4: uptrend holds");
    assert_eq!(directions[2], dec!(1), "Bar 5: uptrend holds through pullback");
    assert_eq!(
        directions[3],
        dec!(-1),
        "Bar 6: downtrend after close drops below lower band"
    );
}
/// Running `compute` twice on the same input yields identical value series.
#[test]
fn deterministic_replay() {
    let indicator = Supertrend::new(10, dec!(2.5)).unwrap();
    let candles: Vec<_> = (0..25)
        .map(|i| {
            let base = Decimal::from(100 + i * 2);
            candle(i, base, base + dec!(5), base - dec!(3), base + dec!(3))
        })
        .collect();
    let first = indicator.compute(&candles).unwrap();
    let second = indicator.compute(&candles).unwrap();
    assert_eq!(
        first.values(),
        second.values(),
        "Identical inputs must produce identical outputs"
    );
}
/// Drives `ratchet_bands` bar by bar over a steadily rising series and checks
/// both halves of the band ratchet rule:
/// - the lower band never moves down on this series;
/// - the upper band never moves up unless the previous close reached it.
///
/// The bands are seeded from `smoothed_atr` at index `period` and then fed
/// forward exactly as the indicator would.
#[test]
fn band_ratchet_invariant() {
    use super::{ratchet_bands, smoothed_atr};
    let candles: Vec<_> = (0..20)
        .map(|i| {
            let base = Decimal::from(100 + i * 2);
            candle(i, base, base + dec!(5), base - dec!(3), base + dec!(3))
        })
        .collect();
    let period = 5;
    let mult = dec!(2);
    let two = Decimal::from(2);
    let atr_vals = smoothed_atr(&candles, period);
    // Seed the bands from the first bar this test treats as having an ATR value.
    let start = period;
    let mid = (candles[start].high() + candles[start].low()) / two;
    let mut prev_lower = mid - mult * atr_vals[start];
    let mut prev_upper = mid + mult * atr_vals[start];
    for i in (start + 1)..candles.len() {
        let prev_close = candles[i - 1].close();
        let mid_i = (candles[i].high() + candles[i].low()) / two;
        let basic_upper = mid_i + mult * atr_vals[i];
        let basic_lower = mid_i - mult * atr_vals[i];
        let (new_lower, new_upper) =
            ratchet_bands(basic_lower, basic_upper, prev_lower, prev_upper, prev_close);
        assert!(
            new_lower >= prev_lower,
            "Bar {}: lower band decreased from {} to {} (ratchet violated)",
            i,
            prev_lower,
            new_lower,
        );
        // The upper band may only reset upward after the previous close
        // reached or broke above it. Otherwise it must hold or tighten.
        assert!(
            new_upper <= prev_upper || prev_close >= prev_upper,
            "Bar {}: upper band increased from {} to {} without a breakout (ratchet violated)",
            i,
            prev_upper,
            new_upper,
        );
        prev_lower = new_lower;
        prev_upper = new_upper;
    }
}
/// The accessors return exactly the parameters passed to `Supertrend::new`.
#[test]
fn accessor_roundtrip() {
    let indicator = Supertrend::new(14, dec!(3.5)).unwrap();
    let (period, multiplier) = (indicator.atr_period(), indicator.multiplier());
    assert_eq!(period, 14);
    assert_eq!(multiplier, dec!(3.5));
}