use super::*;
use crate::test_helpers::helpers::{make_candle, ts};
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
/// Builds `count` candles via `make_candle`, all at the same `price`,
/// stamped `ts(0)` through `ts(count - 1)`.
fn make_candles_constant(price: Decimal, count: usize) -> Vec<quant_primitives::Candle> {
    let mut candles = Vec::with_capacity(count);
    for idx in 0..count {
        candles.push(make_candle(price, ts(idx as i64)));
    }
    candles
}
/// Builds `count` candles on a straight line: candle `i` is priced
/// `start + step * i` and stamped `ts(i)`.
fn make_candles_linear(
    start: Decimal,
    step: Decimal,
    count: usize,
) -> Vec<quant_primitives::Candle> {
    let mut candles = Vec::with_capacity(count);
    for idx in 0..count {
        let offset = step * Decimal::from(idx as u64);
        candles.push(make_candle(start + offset, ts(idx as i64)));
    }
    candles
}
/// Builds `count` candles of deterministic pseudo-noise around `base`.
///
/// Candle `i` is priced `base + sign * amplitude * magnitude`, where `sign` is
/// `+1` when `(7i + 3) mod 5 < 3` and `-1` otherwise, and `magnitude` is
/// `((13i + 7) mod 10) / 10`, i.e. one of 0.0, 0.1, ..., 0.9. Timestamps are `ts(i)`.
fn make_candles_choppy(
    base: Decimal,
    amplitude: Decimal,
    count: usize,
) -> Vec<quant_primitives::Candle> {
    let mut candles = Vec::with_capacity(count);
    for idx in 0..count {
        let sign = match (idx * 7 + 3) % 5 {
            0..=2 => Decimal::ONE,
            _ => Decimal::NEGATIVE_ONE,
        };
        let tenths = Decimal::from((idx * 13 + 7) % 10);
        let magnitude = tenths / Decimal::TEN;
        let price = base + sign * amplitude * magnitude;
        candles.push(make_candle(price, ts(idx as i64)));
    }
    candles
}
/// The warmup length given to the constructor is reported back by `warmup_period()`.
#[test]
fn test_kalman_warmup_period_matches_config() {
    let filter = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 42, 20).unwrap();
    let reported = filter.warmup_period();
    assert_eq!(reported, 42);
}
/// A zero first constructor argument (`q_level`) is rejected as an invalid
/// parameter, and the error text names `q_level`.
#[test]
fn test_kalman_invalid_params_zero_q_level() {
    let result = KalmanFilter::new(dec!(0), dec!(0.01), dec!(1), 50, 20);
    let rejected = matches!(&result, Err(IndicatorError::InvalidParameter { .. }));
    assert!(rejected);

    let msg = result.unwrap_err().to_string();
    assert!(msg.contains("q_level"), "error should mention q_level: {msg}");
}
/// A negative second constructor argument (`q_slope`) is rejected as an invalid parameter.
#[test]
fn test_kalman_invalid_params_negative_q_slope() {
    let negative_slope_noise = dec!(-0.01);
    let result = KalmanFilter::new(dec!(0.01), negative_slope_noise, dec!(1), 50, 20);
    let rejected = matches!(result, Err(IndicatorError::InvalidParameter { .. }));
    assert!(rejected);
}
/// A zero third constructor argument (`r_obs`) is rejected as an invalid
/// parameter, and the error text names `r_obs`.
#[test]
fn test_kalman_invalid_params_zero_r_obs() {
    let zero_obs_noise = dec!(0);
    let result = KalmanFilter::new(dec!(0.01), dec!(0.001), zero_obs_noise, 50, 20);
    let rejected = matches!(&result, Err(IndicatorError::InvalidParameter { .. }));
    assert!(rejected);

    let msg = result.unwrap_err().to_string();
    assert!(msg.contains("r_obs"), "error should mention r_obs: {msg}");
}
/// A zero warmup length (fourth constructor argument) is rejected as an invalid parameter.
#[test]
fn test_kalman_invalid_params_zero_warmup() {
    let rejected = matches!(
        KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 0, 20),
        Err(IndicatorError::InvalidParameter { .. })
    );
    assert!(rejected);
}
/// Every output series holds exactly one value per candle after the warmup window.
#[test]
fn test_kalman_output_lengths_match() {
    let candles = make_candles_linear(dec!(100), dec!(1), 100);
    let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 20).unwrap();
    let result = kf.compute(&candles).unwrap();

    // 100 candles minus a warmup of 20 leaves 80 emitted values per series.
    let expected_len = 80;
    assert_eq!(result.level.len(), expected_len);
    assert_eq!(result.slope.len(), expected_len);
    assert_eq!(result.innovation_variance.len(), expected_len);
    assert_eq!(result.kalman_gain.len(), expected_len);
    assert_eq!(result.normalized_innovation.len(), expected_len);
    assert_eq!(result.len(), expected_len);
}
/// Supplying fewer candles (10) than the warmup length (50) yields `InsufficientData`.
#[test]
fn test_kalman_insufficient_data_error() {
    let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 50, 20).unwrap();
    let too_few = make_candles_constant(dec!(50), 10);
    let outcome = kf.compute(&too_few);
    assert!(matches!(
        outcome,
        Err(IndicatorError::InsufficientData { .. })
    ));
}
/// With exactly `warmup + 1` candles (21 for a warmup of 20) a single value is emitted.
#[test]
fn test_kalman_exact_warmup_plus_one_produces_one_value() {
    let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 20).unwrap();
    let just_enough = make_candles_constant(dec!(50), 21);
    let output = kf.compute(&just_enough).unwrap();
    assert_eq!(output.len(), 1);
}
/// On a flat price series every emitted level stays within 0.1 of the constant price.
#[test]
fn test_kalman_constant_price_level_converges() {
    let target = dec!(50);
    let tolerance = dec!(0.1);
    let candles = make_candles_constant(target, 100);
    let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 20).unwrap();
    let result = kf.compute(&candles).unwrap();

    for (_, val) in result.level.values().iter() {
        let diff = (*val - target).abs();
        assert!(
            diff < tolerance,
            "level should converge to 50, got {val} (diff={diff})"
        );
    }
}
/// On a steadily rising series every emitted slope estimate is strictly positive.
#[test]
fn test_kalman_linear_trend_positive_slope() {
    let uptrend = make_candles_linear(dec!(100), dec!(1), 100);
    let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 20).unwrap();
    let result = kf.compute(&uptrend).unwrap();

    for (_, val) in result.slope.values().iter() {
        assert!(
            Decimal::ZERO < *val,
            "slope should be positive on uptrend, got {val}"
        );
    }
}
/// On a +1-per-candle uptrend, the level estimate stays within 5% of the series'
/// total price range of the true price.
#[test]
fn test_kalman_linear_trend_level_tracks() {
    let candles = make_candles_linear(dec!(100), dec!(1), 100);
    let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 20).unwrap();
    let result = kf.compute(&candles).unwrap();

    // Prices run from 100 to 199 (a range of 99); tolerate 5% of that as lag.
    let price_range = dec!(99);
    let max_lag = price_range * dec!(0.05);

    for (i, (_, val)) in result.level.values().iter().enumerate() {
        // Output index `i` corresponds to candle `i + 20`, skipping the warmup of 20.
        let actual_price = dec!(100) + Decimal::from((i + 20) as u64);
        let diff = (*val - actual_price).abs();
        assert!(
            diff < max_lag,
            "level[{i}] = {val}, expected ~{actual_price}, diff = {diff} exceeds {max_lag}"
        );
    }
}
/// On a noiseless linear series, every one of the last 50 innovation-variance
/// outputs is below 0.01.
#[test]
fn test_kalman_innovation_variance_trending_low() {
    let candles = make_candles_linear(dec!(100), dec!(1), 200);
    let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 50, 20).unwrap();
    let result = kf.compute(&candles).unwrap();
    let values = result.innovation_variance.decimal_values();

    let tail_start = values.len() - 50;
    for (offset, val) in values[tail_start..].iter().enumerate() {
        assert!(
            *val < dec!(0.01),
            "innovation_variance[{}] = {} should be < 0.01 on linear series",
            tail_start + offset,
            val
        );
    }
}
/// Over the last 50 outputs, mean innovation variance on choppy prices exceeds
/// five times the mean on a clean trend.
#[test]
fn test_kalman_innovation_variance_choppy_high() {
    let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 50, 20).unwrap();

    let trending = make_candles_linear(dec!(100), dec!(1), 200);
    let trend_result = kf.compute(&trending).unwrap();
    let trend_vars = trend_result.innovation_variance.decimal_values();
    let trend_mean = mean_last_n(&trend_vars, 50);

    let choppy = make_candles_choppy(dec!(100), dec!(5), 200);
    let choppy_result = kf.compute(&choppy).unwrap();
    let choppy_vars = choppy_result.innovation_variance.decimal_values();
    let choppy_mean = mean_last_n(&choppy_vars, 50);

    assert!(
        choppy_mean > trend_mean * dec!(5),
        "choppy mean ({choppy_mean}) should be > 5x trending mean ({trend_mean})"
    );
}
/// Over the last 50 outputs, the mean absolute normalized innovation is below 1
/// on a clean trend and larger on choppy prices.
#[test]
fn test_kalman_normalized_innovation_regime_separation() {
    let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 50, 20).unwrap();

    let trending = make_candles_linear(dec!(100), dec!(1), 200);
    let trend_result = kf.compute(&trending).unwrap();
    let trend_ni = trend_result.normalized_innovation.decimal_values();
    let trend_abs_mean = mean_abs_last_n(&trend_ni, 50);

    let choppy = make_candles_choppy(dec!(100), dec!(5), 200);
    let choppy_result = kf.compute(&choppy).unwrap();
    let choppy_ni = choppy_result.normalized_innovation.decimal_values();
    let choppy_abs_mean = mean_abs_last_n(&choppy_ni, 50);

    assert!(
        trend_abs_mean < dec!(1),
        "trending |norm_innov| mean ({trend_abs_mean}) should be < 1.0"
    );
    assert!(
        choppy_abs_mean > trend_abs_mean,
        "choppy |norm_innov| mean ({choppy_abs_mean}) should be > trending ({trend_abs_mean})"
    );
}
/// Every emitted Kalman gain lies in the closed interval `[0, 1]`, even on choppy input.
#[test]
fn test_kalman_gain_bounded() {
    let candles = make_candles_choppy(dec!(100), dec!(5), 200);
    let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 20).unwrap();
    let result = kf.compute(&candles).unwrap();

    let unit_interval = Decimal::ZERO..=Decimal::ONE;
    for (_, val) in result.kalman_gain.values().iter() {
        assert!(
            unit_interval.contains(val),
            "kalman_gain should be in [0, 1], got {val}"
        );
    }
}
/// On a clean uptrend the Kalman gain settles: its population variance over
/// the last 20 outputs is below 0.001.
#[test]
fn test_kalman_gain_stabilizes_in_trend() {
    let window = 20;
    let candles = make_candles_linear(dec!(100), dec!(1), 200);
    let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 20).unwrap();
    let result = kf.compute(&candles).unwrap();
    let gains = result.kalman_gain.decimal_values();

    let mean_gain = mean_last_n(&gains, window);
    let tail = &gains[gains.len() - window..];
    let squared_dev_sum: Decimal = tail
        .iter()
        .map(|g| {
            let d = *g - mean_gain;
            d * d
        })
        .sum();
    let variance = squared_dev_sum / Decimal::from(window as u64);

    assert!(
        variance < dec!(0.001),
        "gain variance ({variance}) should be very low in stable trend"
    );
}
/// After 100 trending candles switch to choppy prices around 200, the mean Kalman
/// gain just after the switch exceeds the mean gain late in the trend.
#[test]
fn test_kalman_gain_increases_after_regime_shift() {
    // Candles 0..100: linear trend. Candles 100..200: choppy noise around 200.
    let mut candles = make_candles_linear(dec!(100), dec!(1), 100);
    let choppy = make_candles_choppy(dec!(200), dec!(5), 100);
    for (i, c) in choppy.iter().enumerate() {
        candles.push(make_candle(c.close(), ts((100 + i) as i64)));
    }
    let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 20).unwrap();
    let result = kf.compute(&candles).unwrap();
    let gains = result.kalman_gain.decimal_values();

    // Output index k corresponds to candle k + 20 (the warmup), so gains[60..80]
    // covers candles 80..100 (late trend) and gains[85..90] covers candles 105..110
    // (just after the shift).
    let stable_mean = mean_slice(&gains, 60, 80);
    let post_shift_mean = mean_slice(&gains, 85, 90);
    assert!(
        post_shift_mean > stable_mean,
        "mean(gain[85..90]) (candles 105..110) = {post_shift_mean} should be > mean(gain[60..80]) (candles 80..100) = {stable_mean}"
    );
}
/// An `eiv_window` below 2 (here 1) is rejected as an invalid parameter, and the
/// error text names `eiv_window`.
#[test]
fn test_eiv_window_below_2_rejected() {
    let result = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 1);
    let rejected = matches!(&result, Err(IndicatorError::InvalidParameter { .. }));
    assert!(rejected);

    let msg = result.unwrap_err().to_string();
    assert!(
        msg.contains("eiv_window"),
        "error should mention eiv_window: {msg}"
    );
}
/// Two filters that differ only in `eiv_window` (10 vs 50) produce
/// innovation-variance series that differ in at least one paired position.
#[test]
fn test_eiv_window_changes_output() {
    let candles = make_candles_choppy(dec!(100), dec!(5), 200);

    let short_window = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 10).unwrap();
    let short_out = short_window.compute(&candles).unwrap();
    let vars_short = short_out.innovation_variance.decimal_values();

    let long_window = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 50).unwrap();
    let long_out = long_window.compute(&candles).unwrap();
    let vars_long = long_out.innovation_variance.decimal_values();

    // Compare element-wise over the common prefix only.
    let identical = vars_short
        .iter()
        .zip(vars_long.iter())
        .all(|(a, b)| a == b);
    assert!(
        !identical,
        "different eiv_window values must produce different innovation_variance series"
    );
}
/// Just after a trend-to-choppy switch, a short `eiv_window` (10) reports a higher
/// mean innovation variance than a long one (50).
#[test]
fn test_eiv_window_shorter_detects_regime_shift_faster() {
    // Candles 0..100: linear trend. Candles 100..200: choppy noise around 200.
    let mut candles = make_candles_linear(dec!(100), dec!(1), 100);
    let choppy_tail = make_candles_choppy(dec!(200), dec!(5), 100);
    candles.extend(
        choppy_tail
            .iter()
            .enumerate()
            .map(|(offset, c)| make_candle(c.close(), ts((100 + offset) as i64))),
    );

    let short_window = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 10).unwrap();
    let long_window = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 50).unwrap();
    let short_out = short_window.compute(&candles).unwrap();
    let long_out = long_window.compute(&candles).unwrap();
    let vars_short = short_out.innovation_variance.decimal_values();
    let vars_long = long_out.innovation_variance.decimal_values();

    // With a warmup of 20, output indices 85..95 are candles 105..115, just past the switch.
    let short_near_shift = mean_slice(&vars_short, 85, 95);
    let long_near_shift = mean_slice(&vars_long, 85, 95);
    assert!(
        short_near_shift > long_near_shift,
        "eiv_window=10 near shift ({short_near_shift}) should be > eiv_window=50 ({long_near_shift})"
    );
}
/// Arithmetic mean of the last `n` entries of `values`. If `values` has fewer
/// than `n` entries, all of them are averaged.
///
/// # Panics
/// Panics with a descriptive message when the averaged window is empty
/// (`n == 0` or `values` is empty), instead of dividing by zero.
fn mean_last_n(values: &[Decimal], n: usize) -> Decimal {
    let start = values.len().saturating_sub(n);
    let slice = &values[start..];
    assert!(
        !slice.is_empty(),
        "mean_last_n: nothing to average (values.len() = {}, n = {n})",
        values.len()
    );
    let sum: Decimal = slice.iter().copied().sum();
    sum / Decimal::from(slice.len() as u64)
}
/// Mean of the absolute values of the last `n` entries of `values`. If `values`
/// has fewer than `n` entries, all of them are used.
///
/// # Panics
/// Panics with a descriptive message when the averaged window is empty
/// (`n == 0` or `values` is empty), instead of dividing by zero.
fn mean_abs_last_n(values: &[Decimal], n: usize) -> Decimal {
    let start = values.len().saturating_sub(n);
    let slice = &values[start..];
    assert!(
        !slice.is_empty(),
        "mean_abs_last_n: nothing to average (values.len() = {}, n = {n})",
        values.len()
    );
    let sum: Decimal = slice.iter().map(|v| v.abs()).sum();
    sum / Decimal::from(slice.len() as u64)
}
/// Mean of `values[from..to]`, with both bounds clamped to the slice length.
///
/// Returns `Decimal::ZERO` when the clamped range is empty.
fn mean_slice(values: &[Decimal], from: usize, to: usize) -> Decimal {
    let end = to.min(values.len());
    let start = from.min(end);
    match &values[start..end] {
        [] => Decimal::ZERO,
        window => {
            let total: Decimal = window.iter().copied().sum();
            total / Decimal::from(window.len() as u64)
        }
    }
}