quant-indicators 0.7.0

Pure indicator math library for trading — MA, RSI, Bollinger, MACD, ATR, HRP
Documentation
//! Kalman filter indicator — Local Linear Trend model.
//!
//! Produces five output series from a single state-space model:
//! - `level` — filtered price level
//! - `slope` — trend direction/strength
//! - `innovation_variance` — rolling empirical variance of the prediction errors over `eiv_window` bars (regime signal)
//! - `kalman_gain` — adaptation speed (level component)
//! - `normalized_innovation` — innovation / sqrt(innovation_variance)
//!
//! # State-Space Model
//!
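//! ```text
//! State:       x_t = [level_t, slope_t]'
//! Transition:  x_{t+1} = [[1,1],[0,1]] * x_t + eta_t,  eta ~ N(0, Q_t)
//! Observation: y_t = [1, 0] * x_t + eps_t,  eps ~ N(0, R)
//! Adaptive Q:  Q_t = diag(q_level, q_slope) * (1 + v_t / R),
//!              v_t = rolling variance of the last `eiv_window` innovations
//! ```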
//!
//! # References
//!
//! - Levine & Pedersen (AQR, 2016): all linear trend filters are mathematically equivalent
//! - Benhamou 2018: Sharpe 1.22, whipsaw reduction
//! - Kang 2026: vol-scaled R_t for regime adaptation

use quant_primitives::Candle;
use rust_decimal::Decimal;

use crate::error::IndicatorError;
use crate::series::Series;

/// Kalman filter result containing all five output series.
///
/// When produced by `KalmanFilter::compute`, every series holds one entry per
/// post-warmup candle, keyed by that candle's timestamp.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KalmanResult {
    /// Filtered price level (posterior level estimate after each bar's update).
    pub level: Series,
    /// Trend direction and strength (posterior slope estimate after each bar's update).
    pub slope: Series,
    /// Rolling empirical variance of the innovations over the last `eiv_window`
    /// bars (including the current one) — low = trending, high = choppy.
    pub innovation_variance: Series,
    /// Adaptation speed (level component of Kalman gain, clamped to [0, 1]).
    pub kalman_gain: Series,
    /// Innovation divided by sqrt(innovation_variance) — regime separation signal.
    /// Zero whenever the variance is zero.
    pub normalized_innovation: Series,
}

impl KalmanResult {
    /// Number of post-warmup values, read from the `level` series.
    ///
    /// `KalmanFilter::compute` pushes to all five series in lockstep, so results
    /// produced by it have this length in every series.
    #[must_use]
    pub fn len(&self) -> usize {
        let Self { level, .. } = self;
        level.len()
    }

    /// Returns `true` when the `level` series holds no values.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let Self { level, .. } = self;
        level.is_empty()
    }
}

/// Kalman filter with a Local Linear Trend state-space model.
///
/// The state is `[level, slope]`. Process noise is scaled each bar by an
/// innovation "surprise" ratio `1 + empirical_var / r_obs`, where
/// `empirical_var` is the rolling variance of the last `eiv_window`
/// innovations. Construct with [`KalmanFilter::new`]; run with
/// [`KalmanFilter::compute`].
///
/// # Example
///
/// ```
/// use quant_indicators::KalmanFilter;
/// use quant_primitives::Candle;
/// use chrono::Utc;
/// use rust_decimal_macros::dec;
///
/// let ts = Utc::now();
/// let candles: Vec<Candle> = (0..100).map(|i| {
///     let price = dec!(100) + rust_decimal::Decimal::from(i);
///     Candle::new(price, price, price, price, dec!(1000), ts).unwrap()
/// }).collect();
/// let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 20).unwrap();
/// let result = kf.compute(&candles).unwrap();
/// assert_eq!(result.len(), 80); // 100 - 20
/// ```
#[derive(Debug, Clone)]
pub struct KalmanFilter {
    /// Base process noise for the level state (validated > 0 in `new`).
    q_level: Decimal,
    /// Base process noise for the slope state (validated > 0 in `new`).
    q_slope: Decimal,
    /// Measurement noise `R` (validated > 0); also the baseline of the surprise ratio.
    r_obs: Decimal,
    /// Number of leading candles processed but not emitted (validated >= 1).
    warmup: usize,
    /// Window length of the rolling empirical innovation variance (validated >= 2).
    eiv_window: usize,
    /// Display name of the form `Kalman(q_level,q_slope,r_obs,warmup,eiv_window)`.
    name: String,
}

impl KalmanFilter {
    /// Create a new Kalman filter.
    ///
    /// # Arguments
    ///
    /// * `q_level` — process noise for level state (must be > 0)
    /// * `q_slope` — process noise for slope state (must be > 0)
    /// * `r_obs` — measurement noise (must be > 0)
    /// * `warmup` — bars consumed before output starts (must be ≥ 1)
    /// * `eiv_window` — rolling window size for empirical innovation variance (must be ≥ 2)
    ///
    /// Parameters are checked in the order listed above; the first violation is
    /// the one reported.
    ///
    /// # Errors
    ///
    /// Returns `InvalidParameter` if any noise parameter is ≤ 0, warmup is 0, or
    /// `eiv_window` is less than 2.
    pub fn new(
        q_level: Decimal,
        q_slope: Decimal,
        r_obs: Decimal,
        warmup: usize,
        eiv_window: usize,
    ) -> Result<Self, IndicatorError> {
        let invalid = |message: &str| IndicatorError::InvalidParameter {
            message: message.to_string(),
        };

        if q_level <= Decimal::ZERO {
            return Err(invalid("q_level must be > 0"));
        }
        if q_slope <= Decimal::ZERO {
            return Err(invalid("q_slope must be > 0"));
        }
        if r_obs <= Decimal::ZERO {
            return Err(invalid("r_obs must be > 0"));
        }
        if warmup == 0 {
            return Err(invalid("warmup must be >= 1"));
        }
        if eiv_window < 2 {
            return Err(invalid("eiv_window must be >= 2"));
        }

        Ok(Self {
            q_level,
            q_slope,
            r_obs,
            warmup,
            eiv_window,
            name: format!(
                "Kalman({},{},{},{},{})",
                q_level, q_slope, r_obs, warmup, eiv_window
            ),
        })
    }

    /// Get the indicator name.
    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Number of leading candles that are processed but not emitted.
    ///
    /// `compute` needs at least `warmup_period() + 1` candles to produce output.
    #[must_use]
    pub fn warmup_period(&self) -> usize {
        self.warmup
    }

    /// Compute the Kalman filter on candle data (close prices).
    ///
    /// Every candle, including the first `warmup` ones, is run through the
    /// predict/update cycle. Only candles at index `>= warmup` are emitted, so each
    /// of the five output series has length `candles.len() - warmup`.
    ///
    /// Process noise is adapted per bar: `Q_t = Q * (1 + empirical_var / r_obs)`,
    /// where `empirical_var` is the population variance of the last `eiv_window`
    /// innovations (including the current one).
    ///
    /// # Errors
    ///
    /// Returns `InsufficientData` if `candles.len() <= warmup`.
    pub fn compute(&self, candles: &[Candle]) -> Result<KalmanResult, IndicatorError> {
        let n = candles.len();
        if n <= self.warmup {
            return Err(IndicatorError::InsufficientData {
                // `saturating_add` keeps `warmup == usize::MAX` from overflowing
                // (a panic in debug builds, a bogus `required: 0` in release).
                required: self.warmup.saturating_add(1),
                actual: n,
            });
        }
        let out_len = n - self.warmup;

        // State x = [level, slope], seeded from the first close with zero slope.
        // `candles[0]` exists because `n > warmup >= 1`. The loop re-processes
        // candle 0, so its innovation is exactly zero.
        let mut level = candles[0].close();
        let mut slope = Decimal::ZERO;

        // 2x2 covariance P = [[p00, p01], [p10, p11]], seeded with the identity
        // matrix (unit variance per state, no cross-covariance).
        let mut p00 = Decimal::ONE;
        let mut p01 = Decimal::ZERO;
        let mut p10 = Decimal::ZERO;
        let mut p11 = Decimal::ONE;

        let mut levels = Vec::with_capacity(out_len);
        let mut slopes = Vec::with_capacity(out_len);
        let mut innov_vars = Vec::with_capacity(out_len);
        let mut gains = Vec::with_capacity(out_len);
        let mut norm_innovs = Vec::with_capacity(out_len);

        // Most recent `eiv_window` innovations, oldest at the front. The buffer
        // briefly holds one extra element between push and pop, and can never
        // hold more than `n`, so size it by the smaller bound: a huge (but valid)
        // `eiv_window` must not force a huge up-front allocation.
        let mut innov_ring: std::collections::VecDeque<Decimal> =
            std::collections::VecDeque::with_capacity(self.eiv_window.min(n) + 1);

        for (i, candle) in candles.iter().enumerate() {
            let z = candle.close();

            // --- Predict state: x_pred = F * x = [level + slope, slope] ---
            let level_pred = level + slope;
            let slope_pred = slope;

            // --- Innovation: y = z - H * x_pred = z - level_pred (H = [1, 0]) ---
            // Computed before P_pred because it drives the adaptive Q below.
            let innovation = z - level_pred;

            innov_ring.push_back(innovation);
            if innov_ring.len() > self.eiv_window {
                innov_ring.pop_front();
            }
            // Population variance of the windowed innovations, this bar's included.
            let empirical_var = rolling_variance(&innov_ring);

            // Adaptive Q (Kang 2026): scale process noise by the surprise ratio.
            // Innovations that are large relative to R inflate Q, which grows
            // P_pred and therefore the gain.
            let surprise = Decimal::ONE + decimal_div(empirical_var, self.r_obs);
            let q_level_t = self.q_level * surprise;
            let q_slope_t = self.q_slope * surprise;

            // P_pred = F * P * F' + Q_t, with F = [[1,1],[0,1]], F' = [[1,0],[1,1]]:
            //   F*P      = [[p00+p10, p01+p11], [p10, p11]]
            //   (F*P)*F' = [[(p00+p10)+(p01+p11), p01+p11], [p10+p11, p11]]
            let pp00 = p00 + p10 + p01 + p11 + q_level_t;
            let pp01 = p01 + p11;
            let pp10 = p10 + p11;
            let pp11 = p11 + q_slope_t;

            // Innovation covariance S = H * P_pred * H' + R = pp00 + R.
            let s = pp00 + self.r_obs;

            // --- Kalman gain: K = P_pred * H' / S = [pp00 / S, pp10 / S] ---
            let k0 = decimal_div(pp00, s);
            let k1 = decimal_div(pp10, s);

            // --- Update state: x = x_pred + K * y ---
            level = level_pred + k0 * innovation;
            slope = slope_pred + k1 * innovation;

            // --- Update covariance: P = (I - K*H) * P_pred ---
            // K*H = [[k0, 0], [k1, 0]], so I - K*H = [[1-k0, 0], [-k1, 1]].
            p00 = (Decimal::ONE - k0) * pp00;
            p01 = (Decimal::ONE - k0) * pp01;
            p10 = pp10 - k1 * pp00;
            p11 = pp11 - k1 * pp01;

            // Emit only after the warmup bars.
            if i >= self.warmup {
                let ts = candle.timestamp();
                levels.push((ts, level));
                slopes.push((ts, slope));
                innov_vars.push((ts, empirical_var));

                // k0 = pp00 / (pp00 + R) normally lies in (0, 1); the clamp guards
                // the published gain against rounding artefacts.
                let gain_clamped = clamp_decimal(k0, Decimal::ZERO, Decimal::ONE);
                gains.push((ts, gain_clamped));

                let norm_innov = normalized_innovation(innovation, empirical_var);
                norm_innovs.push((ts, norm_innov));
            }
        }

        Ok(KalmanResult {
            level: Series::new(levels),
            slope: Series::new(slopes),
            innovation_variance: Series::new(innov_vars),
            kalman_gain: Series::new(gains),
            normalized_innovation: Series::new(norm_innovs),
        })
    }
}

/// Population variance of the buffered values, computed as `mean(v²) - mean(v)²`.
///
/// Yields `ZERO` when the buffer holds fewer than two values. A slightly negative
/// result, which `Decimal` rounding in the final subtraction can produce, is
/// clamped to `ZERO`.
fn rolling_variance(buf: &std::collections::VecDeque<Decimal>) -> Decimal {
    let len = buf.len();
    if len < 2 {
        return Decimal::ZERO;
    }

    // Single pass; both accumulators add in buffer order starting from ZERO.
    let (total, total_sq) = buf
        .iter()
        .fold((Decimal::ZERO, Decimal::ZERO), |(acc, acc_sq), &v| {
            (acc + v, acc_sq + v * v)
        });

    let count = Decimal::from(len as u64);
    let mean = total / count;
    let variance = total_sq / count - mean * mean;

    if variance < Decimal::ZERO {
        Decimal::ZERO
    } else {
        variance
    }
}

/// Scale an innovation by the rolling standard deviation: `innovation / sqrt(empirical_var)`.
///
/// Yields `ZERO` unless both the variance and its computed square root are
/// strictly positive.
fn normalized_innovation(innovation: Decimal, empirical_var: Decimal) -> Decimal {
    if empirical_var > Decimal::ZERO {
        let std_dev = decimal_sqrt(empirical_var);
        if std_dev > Decimal::ZERO {
            return decimal_div(innovation, std_dev);
        }
    }
    Decimal::ZERO
}

/// Divide `num` by `den`, yielding `ZERO` instead of dividing when `den` is zero.
///
/// Only the zero-divisor case is intercepted; every other division is a plain `/`.
fn decimal_div(num: Decimal, den: Decimal) -> Decimal {
    if den.is_zero() {
        Decimal::ZERO
    } else {
        num / den
    }
}

/// Restrict `val` to the closed interval `[lo, hi]`.
///
/// The lower bound is tested first, so `lo` wins if the bounds are inverted. A
/// value inside the range is returned unchanged, scale included.
fn clamp_decimal(val: Decimal, lo: Decimal, hi: Decimal) -> Decimal {
    match (val < lo, val > hi) {
        (true, _) => lo,
        (false, true) => hi,
        (false, false) => val,
    }
}

/// Newton-Raphson square root for `Decimal`.
///
/// Returns `ZERO` for non-positive input and `ONE` for exactly one. Otherwise it
/// iterates `g <- (g + value / g) / 2` from `value / 2`, for at most 100 steps. It
/// stops once the step size is below an absolute tolerance of `1e-12` **and** at
/// most `1e-6` times the current estimate.
///
/// The relative test only matters for tiny inputs. With the absolute test alone,
/// any input whose root lies below about `1e-12` stopped while the iterate was
/// still roughly halving towards the root. For example, `value = 1e-26` returned
/// a value near `1e-12` instead of `1e-13`. For estimates `>= 1e-6` the relative
/// test is implied by the absolute one, so results for ordinary inputs are
/// unchanged.
fn decimal_sqrt(value: Decimal) -> Decimal {
    if value <= Decimal::ZERO {
        return Decimal::ZERO;
    }
    if value == Decimal::ONE {
        return Decimal::ONE;
    }

    let abs_epsilon = Decimal::new(1, 12); // 1e-12
    let rel_epsilon = Decimal::new(1, 6); // 1e-6

    // Initial guess: value / 2. For the smallest representable inputs the halving
    // can round to zero; any positive seed converges, so fall back to `value`.
    let mut guess = value / Decimal::TWO;
    if guess <= Decimal::ZERO {
        guess = value;
    }

    for _ in 0..100 {
        // Defensive: a non-positive iterate would make `value / guess` meaningless.
        if guess <= Decimal::ZERO {
            return Decimal::ZERO;
        }
        let next = (guess + value / guess) / Decimal::TWO;
        let diff = if next > guess {
            next - guess
        } else {
            guess - next
        };
        guess = next;
        if diff < abs_epsilon && diff <= next * rel_epsilon {
            break;
        }
    }

    guess
}

#[cfg(test)]
#[path = "kalman_tests.rs"]
mod tests;