Skip to main content

quant_indicators/
kalman.rs

//! Kalman filter indicator — Local Linear Trend model.
//!
//! Produces five output series from a single state-space model:
//! - `level` — filtered price level
//! - `slope` — filtered trend direction/strength
//! - `innovation_variance` — rolling empirical variance of the one-step prediction errors
//!   over the last `eiv_window` bars (regime signal)
//! - `kalman_gain` — adaptation speed (level component of the gain, clamped to [0, 1])
//! - `normalized_innovation` — innovation / sqrt(innovation_variance)
//!
//! # State-Space Model
//!
//! ```text
//! State:       x_t = [level_t, slope_t]'
//! Transition:  x_{t+1} = [[1,1],[0,1]] * x_t + eta_t,  eta ~ N(0, Q_t)
//! Observation: y_t = [1, 0] * x_t + eps_t,  eps ~ N(0, R)
//! Adaptive Q:  Q_t = diag(q_level, q_slope) * (1 + v_t / R)
//!              v_t = rolling variance of the innovations over `eiv_window` bars
//! ```
//!
//! The measurement noise `R` (`r_obs`) is held fixed; regime adaptation is applied
//! through the process noise `Q_t`.
//!
//! # References
//!
//! - Levine & Pedersen (AQR, 2016): all linear trend filters are mathematically equivalent
//! - Benhamou 2018: Sharpe 1.22, whipsaw reduction
//! - Kang 2026: vol-scaled R_t for regime adaptation

24use quant_primitives::Candle;
25use rust_decimal::Decimal;
26
27use crate::error::IndicatorError;
28use crate::series::Series;
29
30/// Kalman filter result containing all five output series.
31#[derive(Debug, Clone, PartialEq, Eq)]
32pub struct KalmanResult {
33    /// Filtered price level.
34    pub level: Series,
35    /// Trend direction and strength.
36    pub slope: Series,
37    /// Prediction error variance — low = trending, high = choppy.
38    pub innovation_variance: Series,
39    /// Adaptation speed (level component of Kalman gain, bounded [0, 1]).
40    pub kalman_gain: Series,
41    /// Innovation divided by sqrt(innovation_variance) — regime separation signal.
42    pub normalized_innovation: Series,
43}
44
45impl KalmanResult {
46    /// Number of output values (post-warmup).
47    #[must_use]
48    pub fn len(&self) -> usize {
49        self.level.len()
50    }
51
52    /// Whether the result is empty.
53    #[must_use]
54    pub fn is_empty(&self) -> bool {
55        self.level.is_empty()
56    }
57}
58
59/// Kalman filter with a Local Linear Trend state-space model.
60///
61/// # Example
62///
63/// ```
64/// use quant_indicators::KalmanFilter;
65/// use quant_primitives::Candle;
66/// use chrono::Utc;
67/// use rust_decimal_macros::dec;
68///
69/// let ts = Utc::now();
70/// let candles: Vec<Candle> = (0..100).map(|i| {
71///     let price = dec!(100) + rust_decimal::Decimal::from(i);
72///     Candle::new(price, price, price, price, dec!(1000), ts).unwrap()
73/// }).collect();
74/// let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 20).unwrap();
75/// let result = kf.compute(&candles).unwrap();
76/// assert_eq!(result.len(), 80); // 100 - 20
77/// ```
78#[derive(Debug, Clone)]
79pub struct KalmanFilter {
80    q_level: Decimal,
81    q_slope: Decimal,
82    r_obs: Decimal,
83    warmup: usize,
84    eiv_window: usize,
85    name: String,
86}
87
88impl KalmanFilter {
89    /// Create a new Kalman filter.
90    ///
91    /// # Arguments
92    ///
93    /// * `q_level` — process noise for level state (must be > 0)
94    /// * `q_slope` — process noise for slope state (must be > 0)
95    /// * `r_obs` — measurement noise (must be > 0)
96    /// * `warmup` — bars consumed before output starts (must be ≥ 1)
97    /// * `eiv_window` — rolling window size for empirical innovation variance (must be ≥ 2)
98    ///
99    /// # Errors
100    ///
101    /// Returns `InvalidParameter` if any noise parameter is ≤ 0, warmup is 0, or
102    /// `eiv_window` is less than 2.
103    pub fn new(
104        q_level: Decimal,
105        q_slope: Decimal,
106        r_obs: Decimal,
107        warmup: usize,
108        eiv_window: usize,
109    ) -> Result<Self, IndicatorError> {
110        if q_level <= Decimal::ZERO {
111            return Err(IndicatorError::InvalidParameter {
112                message: "q_level must be > 0".to_string(),
113            });
114        }
115        if q_slope <= Decimal::ZERO {
116            return Err(IndicatorError::InvalidParameter {
117                message: "q_slope must be > 0".to_string(),
118            });
119        }
120        if r_obs <= Decimal::ZERO {
121            return Err(IndicatorError::InvalidParameter {
122                message: "r_obs must be > 0".to_string(),
123            });
124        }
125        if warmup == 0 {
126            return Err(IndicatorError::InvalidParameter {
127                message: "warmup must be >= 1".to_string(),
128            });
129        }
130        if eiv_window < 2 {
131            return Err(IndicatorError::InvalidParameter {
132                message: "eiv_window must be >= 2".to_string(),
133            });
134        }
135
136        Ok(Self {
137            q_level,
138            q_slope,
139            r_obs,
140            warmup,
141            eiv_window,
142            name: format!(
143                "Kalman({},{},{},{},{})",
144                q_level, q_slope, r_obs, warmup, eiv_window
145            ),
146        })
147    }
148
149    /// Get the indicator name.
150    #[must_use]
151    pub fn name(&self) -> &str {
152        &self.name
153    }
154
155    /// Minimum number of candles consumed before output starts.
156    #[must_use]
157    pub fn warmup_period(&self) -> usize {
158        self.warmup
159    }
160
161    /// Compute the Kalman filter on candle data.
162    ///
163    /// Returns `KalmanResult` with five output series, each of length
164    /// `candles.len() - warmup`.
165    ///
166    /// # Errors
167    ///
168    /// Returns `InsufficientData` if `candles.len() <= warmup`.
169    pub fn compute(&self, candles: &[Candle]) -> Result<KalmanResult, IndicatorError> {
170        if candles.len() <= self.warmup {
171            return Err(IndicatorError::InsufficientData {
172                required: self.warmup + 1,
173                actual: candles.len(),
174            });
175        }
176
177        let n = candles.len();
178        let out_len = n - self.warmup;
179
180        // State: [level, slope]
181        // Initialize from first candle
182        let mut level = candles[0].close();
183        let mut slope = Decimal::ZERO;
184
185        // 2x2 covariance matrix P = [[p00, p01], [p10, p11]]
186        // Initialize with large uncertainty
187        let mut p00 = Decimal::ONE;
188        let mut p01 = Decimal::ZERO;
189        let mut p10 = Decimal::ZERO;
190        let mut p11 = Decimal::ONE;
191
192        let mut levels = Vec::with_capacity(out_len);
193        let mut slopes = Vec::with_capacity(out_len);
194        let mut innov_vars = Vec::with_capacity(out_len);
195        let mut gains = Vec::with_capacity(out_len);
196        let mut norm_innovs = Vec::with_capacity(out_len);
197
198        // Ring buffer for rolling innovation variance over eiv_window bars.
199        let mut innov_ring: std::collections::VecDeque<Decimal> =
200            std::collections::VecDeque::with_capacity(self.eiv_window);
201        let mut empirical_var;
202
203        for (i, candle) in candles.iter().enumerate() {
204            let z = candle.close();
205
206            // --- Predict ---
207            // x_pred = F * x = [level + slope, slope]
208            let level_pred = level + slope;
209            let slope_pred = slope;
210
211            // --- Innovation (computed before P_pred to drive adaptive Q) ---
212            // y = z - H * x_pred = z - level_pred  (H = [1, 0])
213            let innovation = z - level_pred;
214
215            // Rolling innovation variance over eiv_window bars.
216            // Maintains a ring buffer of raw innovations and computes
217            // variance = mean(v²) - mean(v)² over the window.
218            innov_ring.push_back(innovation);
219            if innov_ring.len() > self.eiv_window {
220                innov_ring.pop_front();
221            }
222            empirical_var = rolling_variance(&innov_ring);
223
224            // Adaptive Q: scale process noise by innovation surprise ratio
225            // (Kang 2026). When innovations are large relative to baseline R,
226            // process noise inflates → P_pred grows → gain increases.
227            let surprise = Decimal::ONE + decimal_div(empirical_var, self.r_obs);
228            let q_level_t = self.q_level * surprise;
229            let q_slope_t = self.q_slope * surprise;
230
231            // P_pred = F * P * F' + Q_t
232            // F = [[1,1],[0,1]], F' = [[1,0],[1,1]]
233            // F*P = [[p00+p10, p01+p11],[p10, p11]]
234            // (F*P)*F' = [[(p00+p10)+(p01+p11), p01+p11],[p10+p11, p11]]
235            let pp00 = p00 + p10 + p01 + p11 + q_level_t;
236            let pp01 = p01 + p11;
237            let pp10 = p10 + p11;
238            let pp11 = p11 + q_slope_t;
239
240            // S = H * P_pred * H' + R = pp00 + R
241            let s = pp00 + self.r_obs;
242
243            // --- Kalman gain ---
244            // K = P_pred * H' / S = [pp00/S, pp10/S]
245            let k0 = decimal_div(pp00, s);
246            let k1 = decimal_div(pp10, s);
247
248            // --- Update ---
249            // x = x_pred + K * y
250            level = level_pred + k0 * innovation;
251            slope = slope_pred + k1 * innovation;
252
253            // P = (I - K*H) * P_pred
254            // K*H = [[k0, 0],[k1, 0]]
255            // I - K*H = [[1-k0, 0],[-k1, 1]]
256            // (I - K*H) * P_pred:
257            p00 = (Decimal::ONE - k0) * pp00;
258            p01 = (Decimal::ONE - k0) * pp01;
259            p10 = pp10 - k1 * pp00;
260            p11 = pp11 - k1 * pp01;
261
262            // Emit output after warmup
263            if i >= self.warmup {
264                let ts = candle.timestamp();
265                levels.push((ts, level));
266                slopes.push((ts, slope));
267                innov_vars.push((ts, empirical_var));
268
269                // Clamp gain to [0, 1] for safety
270                let gain_clamped = clamp_decimal(k0, Decimal::ZERO, Decimal::ONE);
271                gains.push((ts, gain_clamped));
272
273                let norm_innov = normalized_innovation(innovation, empirical_var);
274                norm_innovs.push((ts, norm_innov));
275            }
276        }
277
278        Ok(KalmanResult {
279            level: Series::new(levels),
280            slope: Series::new(slopes),
281            innovation_variance: Series::new(innov_vars),
282            kalman_gain: Series::new(gains),
283            normalized_innovation: Series::new(norm_innovs),
284        })
285    }
286}
287
288/// Rolling variance of values in a `VecDeque`: `mean(v²) - mean(v)²`.
289///
290/// Returns `ZERO` when the buffer has fewer than 2 elements.
291fn rolling_variance(buf: &std::collections::VecDeque<Decimal>) -> Decimal {
292    if buf.len() < 2 {
293        return Decimal::ZERO;
294    }
295    let n = Decimal::from(buf.len() as u64);
296    let sum: Decimal = buf.iter().copied().sum();
297    let sum_sq: Decimal = buf.iter().map(|v| *v * *v).sum();
298    let mean = sum / n;
299    let var = sum_sq / n - mean * mean;
300    // Clamp to zero in case of floating-point drift
301    if var < Decimal::ZERO {
302        Decimal::ZERO
303    } else {
304        var
305    }
306}
307
308/// Compute normalized innovation: innovation / sqrt(empirical_var).
309///
310/// Returns ZERO if the variance is non-positive or its square root is zero.
311fn normalized_innovation(innovation: Decimal, empirical_var: Decimal) -> Decimal {
312    if empirical_var <= Decimal::ZERO {
313        return Decimal::ZERO;
314    }
315    let sqrt_v = decimal_sqrt(empirical_var);
316    if sqrt_v <= Decimal::ZERO {
317        return Decimal::ZERO;
318    }
319    decimal_div(innovation, sqrt_v)
320}
321
322/// Safe decimal division — returns ZERO on divide-by-zero.
323fn decimal_div(num: Decimal, den: Decimal) -> Decimal {
324    if den == Decimal::ZERO {
325        return Decimal::ZERO;
326    }
327    num / den
328}
329
330/// Clamp a Decimal to [lo, hi].
331fn clamp_decimal(val: Decimal, lo: Decimal, hi: Decimal) -> Decimal {
332    if val < lo {
333        lo
334    } else if val > hi {
335        hi
336    } else {
337        val
338    }
339}
340
341/// Newton-Raphson square root for Decimal.
342///
343/// Iterates until the estimate changes by less than `epsilon`.
344fn decimal_sqrt(value: Decimal) -> Decimal {
345    if value <= Decimal::ZERO {
346        return Decimal::ZERO;
347    }
348    if value == Decimal::ONE {
349        return Decimal::ONE;
350    }
351
352    // Initial guess: value / 2 (works for most ranges)
353    let mut guess = value / Decimal::TWO;
354    let epsilon = Decimal::new(1, 12); // 1e-12
355
356    for _ in 0..100 {
357        if guess <= Decimal::ZERO {
358            return Decimal::ZERO;
359        }
360        let next = (guess + value / guess) / Decimal::TWO;
361        let diff = if next > guess {
362            next - guess
363        } else {
364            guess - next
365        };
366        guess = next;
367        if diff < epsilon {
368            break;
369        }
370    }
371
372    guess
373}
374
375#[cfg(test)]
376#[path = "kalman_tests.rs"]
377mod tests;