use quant_primitives::Candle;
use rust_decimal::Decimal;
use crate::error::IndicatorError;
use crate::series::Series;
/// Output of [`KalmanFilter::compute`]: five series that `compute` fills in
/// lockstep, one point per candle from index `warmup` onward.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KalmanResult {
/// Level estimate after the measurement update for each emitted candle.
pub level: Series,
/// Slope estimate after the measurement update for each emitted candle.
pub slope: Series,
/// Rolling empirical variance of the most recent innovations (at most
/// `eiv_window` of them), as computed at each emitted candle.
pub innovation_variance: Series,
/// Gain applied to the level state, clamped to `[0, 1]` before being stored.
pub kalman_gain: Series,
/// Innovation divided by the square root of the empirical variance;
/// zero whenever that variance is not strictly positive.
pub normalized_innovation: Series,
}
impl KalmanResult {
/// Returns the number of points reported by the `level` series.
///
/// Only `level` is consulted; the other series are not checked here.
#[must_use]
pub fn len(&self) -> usize {
self.level.len()
}
/// Returns `true` when the `level` series reports itself as empty.
#[must_use]
pub fn is_empty(&self) -> bool {
self.level.is_empty()
}
}
/// Two-state (level + slope) Kalman filter over candle closes whose process
/// noise is scaled up by the rolling empirical variance of its innovations.
///
/// Fields are private; `KalmanFilter::new` validates each of them.
#[derive(Debug, Clone)]
pub struct KalmanFilter {
// Base process-noise term added to the level variance each step; validated > 0.
q_level: Decimal,
// Base process-noise term added to the slope variance each step; validated > 0.
q_slope: Decimal,
// Observation-noise variance added to the predicted level variance; validated > 0.
r_obs: Decimal,
// Number of leading candles that are filtered but not emitted; validated >= 1.
warmup: usize,
// Maximum number of recent innovations kept for the empirical variance; validated >= 2.
eiv_window: usize,
// Display name built once in `new` as "Kalman(q_level,q_slope,r_obs,warmup,eiv_window)".
name: String,
}
impl KalmanFilter {
    /// Builds a filter after validating every tuning parameter.
    ///
    /// * `q_level` - base process-noise term for the level state; must be `> 0`.
    /// * `q_slope` - base process-noise term for the slope state; must be `> 0`.
    /// * `r_obs` - observation-noise variance; must be `> 0`.
    /// * `warmup` - number of leading candles that are filtered but not
    ///   emitted; must be `>= 1`.
    /// * `eiv_window` - how many recent innovations feed the rolling empirical
    ///   variance; must be `>= 2`.
    ///
    /// # Errors
    ///
    /// Returns [`IndicatorError::InvalidParameter`] describing the first
    /// parameter (checked in the order listed above) that violates its bound.
    pub fn new(
        q_level: Decimal,
        q_slope: Decimal,
        r_obs: Decimal,
        warmup: usize,
        eiv_window: usize,
    ) -> Result<Self, IndicatorError> {
        // Every rejection has the same shape; only the message differs.
        let invalid = |message: &str| IndicatorError::InvalidParameter {
            message: message.to_string(),
        };

        if q_level <= Decimal::ZERO {
            return Err(invalid("q_level must be > 0"));
        }
        if q_slope <= Decimal::ZERO {
            return Err(invalid("q_slope must be > 0"));
        }
        if r_obs <= Decimal::ZERO {
            return Err(invalid("r_obs must be > 0"));
        }
        if warmup == 0 {
            return Err(invalid("warmup must be >= 1"));
        }
        if eiv_window < 2 {
            return Err(invalid("eiv_window must be >= 2"));
        }

        Ok(Self {
            q_level,
            q_slope,
            r_obs,
            warmup,
            eiv_window,
            name: format!(
                "Kalman({},{},{},{},{})",
                q_level, q_slope, r_obs, warmup, eiv_window
            ),
        })
    }

    /// Returns the display name assembled in [`KalmanFilter::new`].
    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the number of leading candles that produce no output points.
    #[must_use]
    pub fn warmup_period(&self) -> usize {
        self.warmup
    }

    /// Runs the filter over `candles` (using each candle's close) and returns
    /// the post-warm-up series.
    ///
    /// The state is `[level, slope]` with the transition `level += slope` and
    /// an observation of the level only. It starts at the first close with
    /// zero slope and an identity covariance. Every candle, including those
    /// inside the warm-up, updates the state; only candles at index
    /// `>= warmup` are emitted, so each output series has
    /// `candles.len() - warmup` points.
    ///
    /// Both process-noise terms are multiplied each step by
    /// `1 + empirical_var / r_obs`, where `empirical_var` is the rolling
    /// variance of the last `eiv_window` innovations (current one included).
    ///
    /// # Errors
    ///
    /// Returns [`IndicatorError::InsufficientData`] when
    /// `candles.len() <= warmup`.
    pub fn compute(&self, candles: &[Candle]) -> Result<KalmanResult, IndicatorError> {
        if candles.len() <= self.warmup {
            return Err(IndicatorError::InsufficientData {
                // Saturating: a plain `warmup + 1` overflows for `usize::MAX`.
                required: self.warmup.saturating_add(1),
                actual: candles.len(),
            });
        }
        let out_len = candles.len() - self.warmup;

        // Initial state: level at the first close, flat slope, identity covariance.
        let mut level = candles[0].close();
        let mut slope = Decimal::ZERO;
        let mut p00 = Decimal::ONE;
        let mut p01 = Decimal::ZERO;
        let mut p10 = Decimal::ZERO;
        let mut p11 = Decimal::ONE;

        let mut levels = Vec::with_capacity(out_len);
        let mut slopes = Vec::with_capacity(out_len);
        let mut innov_vars = Vec::with_capacity(out_len);
        let mut gains = Vec::with_capacity(out_len);
        let mut norm_innovs = Vec::with_capacity(out_len);

        // Sliding window of the most recent innovations.
        let mut innov_ring: std::collections::VecDeque<Decimal> =
            std::collections::VecDeque::with_capacity(self.eiv_window);

        for (i, candle) in candles.iter().enumerate() {
            let z = candle.close();

            // Predict the state one step ahead and measure the surprise.
            let level_pred = level + slope;
            let slope_pred = slope;
            let innovation = z - level_pred;

            innov_ring.push_back(innovation);
            if innov_ring.len() > self.eiv_window {
                innov_ring.pop_front();
            }
            let empirical_var = rolling_variance(&innov_ring);

            // Inflate process noise when recent innovations are large relative
            // to the assumed observation noise.
            let surprise = Decimal::ONE + decimal_div(empirical_var, self.r_obs);
            let q_level_t = self.q_level * surprise;
            let q_slope_t = self.q_slope * surprise;

            // Predicted covariance: F * P * F^T + Q with F = [[1, 1], [0, 1]].
            let pp00 = p00 + p10 + p01 + p11 + q_level_t;
            let pp01 = p01 + p11;
            let pp10 = p10 + p11;
            let pp11 = p11 + q_slope_t;

            // Innovation covariance and gain for an observation of the level only.
            let s = pp00 + self.r_obs;
            let k0 = decimal_div(pp00, s);
            let k1 = decimal_div(pp10, s);

            // Measurement update of the state ...
            level = level_pred + k0 * innovation;
            slope = slope_pred + k1 * innovation;

            // ... and of the covariance: (I - K * H) * P_pred.
            p00 = (Decimal::ONE - k0) * pp00;
            p01 = (Decimal::ONE - k0) * pp01;
            p10 = pp10 - k1 * pp00;
            p11 = pp11 - k1 * pp01;

            if i >= self.warmup {
                let ts = candle.timestamp();
                levels.push((ts, level));
                slopes.push((ts, slope));
                innov_vars.push((ts, empirical_var));
                let gain_clamped = clamp_decimal(k0, Decimal::ZERO, Decimal::ONE);
                gains.push((ts, gain_clamped));
                let norm_innov = normalized_innovation(innovation, empirical_var);
                norm_innovs.push((ts, norm_innov));
            }
        }

        Ok(KalmanResult {
            level: Series::new(levels),
            slope: Series::new(slopes),
            innovation_variance: Series::new(innov_vars),
            kalman_gain: Series::new(gains),
            normalized_innovation: Series::new(norm_innovs),
        })
    }
}
/// Population variance (`E[x^2] - E[x]^2`) of the buffered values.
///
/// Yields zero for fewer than two samples, and floors a negative result
/// (possible only through rounding) at zero.
fn rolling_variance(buf: &std::collections::VecDeque<Decimal>) -> Decimal {
    let count = buf.len();
    if count < 2 {
        return Decimal::ZERO;
    }

    // Accumulate the sum and the sum of squares front-to-back in one pass.
    let (total, total_sq) = buf.iter().fold(
        (Decimal::ZERO, Decimal::ZERO),
        |(acc, acc_sq), &x| (acc + x, acc_sq + x * x),
    );

    let n = Decimal::from(count as u64);
    let mean = total / n;
    let variance = total_sq / n - mean * mean;

    if variance >= Decimal::ZERO {
        variance
    } else {
        Decimal::ZERO
    }
}
/// Scales `innovation` by the square root of `empirical_var`.
///
/// Falls back to zero when the variance, or the square root computed from
/// it, is not strictly positive.
fn normalized_innovation(innovation: Decimal, empirical_var: Decimal) -> Decimal {
    if empirical_var > Decimal::ZERO {
        let std_dev = decimal_sqrt(empirical_var);
        if std_dev > Decimal::ZERO {
            return decimal_div(innovation, std_dev);
        }
    }
    Decimal::ZERO
}
/// Divides `num` by `den`, returning zero instead of dividing when `den` is zero.
fn decimal_div(num: Decimal, den: Decimal) -> Decimal {
    if den != Decimal::ZERO {
        num / den
    } else {
        Decimal::ZERO
    }
}
/// Restricts `val` to the range `[lo, hi]`.
///
/// The lower bound is tested first, so a value below `lo` yields `lo` even
/// if the bounds are passed in inverted order.
fn clamp_decimal(val: Decimal, lo: Decimal, hi: Decimal) -> Decimal {
    match val {
        below if below < lo => lo,
        above if above > hi => hi,
        within => within,
    }
}
/// Approximates the square root of `value` with Newton's method.
///
/// Returns zero for non-positive input and one for an input of exactly one.
/// Otherwise it iterates from `value / 2` for at most 100 steps, stopping
/// once successive estimates agree to within the tolerance described below.
///
/// The tolerance is `1e-12` in absolute terms for roots of one or more. For
/// roots below one it tightens to `1e-12` relative to the estimate. A purely
/// absolute threshold would stop too early on tiny inputs, because the
/// halving step itself drops below `1e-12` long before the estimate reaches
/// the root (for example, `1e-26` would yield roughly `1e-12` instead of
/// `1e-13`).
fn decimal_sqrt(value: Decimal) -> Decimal {
    if value <= Decimal::ZERO {
        return Decimal::ZERO;
    }
    if value == Decimal::ONE {
        return Decimal::ONE;
    }

    let mut guess = value / Decimal::TWO;
    let epsilon = Decimal::new(1, 12);

    for _ in 0..100 {
        // Guards the division below; the initial guess can round to zero.
        if guess <= Decimal::ZERO {
            return Decimal::ZERO;
        }
        let next = (guess + value / guess) / Decimal::TWO;
        let diff = (next - guess).abs();

        // Absolute bound for estimates >= 1, relative bound below that.
        let tolerance = epsilon.min(next * epsilon);

        guess = next;
        if diff < tolerance || diff == Decimal::ZERO {
            break;
        }
    }
    guess
}
// Unit tests are kept in the sibling file `kalman_tests.rs` and are compiled
// only for test builds.
#[cfg(test)]
#[path = "kalman_tests.rs"]
mod tests;