quant_indicators/kalman.rs
1//! Kalman filter indicator — Local Linear Trend model.
2//!
3//! Produces five output series from a single state-space model:
4//! - `level` — filtered price level
5//! - `slope` — trend direction/strength
6//! - `innovation_variance` — rolling empirical variance of the prediction errors (innovations) over the last `eiv_window` bars (regime signal)
7//! - `kalman_gain` — adaptation speed (level component)
8//! - `normalized_innovation` — innovation / sqrt(innovation_variance)
9//!
10//! # State-Space Model
11//!
12//! ```text
13//! State: x_t = [level_t, slope_t]'
14//! Transition: x_{t+1} = [[1,1],[0,1]] * x_t + eta_t, eta ~ N(0, Q)
15//! Observation: y_t = [1, 0] * x_t + eps_t, eps ~ N(0, R)
16//! ```
17//!
18//! # References
19//!
20//! - Levine & Pedersen (AQR, 2016): all linear trend filters are mathematically equivalent
21//! - Benhamou 2018: Sharpe 1.22, whipsaw reduction
22//! - Kang 2026: volatility-scaled noise for regime adaptation (this implementation scales the process noise: Q_t = Q * (1 + var(innovations) / R))
23
24use quant_primitives::Candle;
25use rust_decimal::Decimal;
26
27use crate::error::IndicatorError;
28use crate::series::Series;
29
30/// Kalman filter result containing all five output series.
31#[derive(Debug, Clone, PartialEq, Eq)]
32pub struct KalmanResult {
33 /// Filtered price level.
34 pub level: Series,
35 /// Trend direction and strength.
36 pub slope: Series,
37 /// Prediction error variance — low = trending, high = choppy.
38 pub innovation_variance: Series,
39 /// Adaptation speed (level component of Kalman gain, bounded [0, 1]).
40 pub kalman_gain: Series,
41 /// Innovation divided by sqrt(innovation_variance) — regime separation signal.
42 pub normalized_innovation: Series,
43}
44
45impl KalmanResult {
46 /// Number of output values (post-warmup).
47 #[must_use]
48 pub fn len(&self) -> usize {
49 self.level.len()
50 }
51
52 /// Whether the result is empty.
53 #[must_use]
54 pub fn is_empty(&self) -> bool {
55 self.level.is_empty()
56 }
57}
58
59/// Kalman filter with a Local Linear Trend state-space model.
60///
61/// # Example
62///
63/// ```
64/// use quant_indicators::KalmanFilter;
65/// use quant_primitives::Candle;
66/// use chrono::Utc;
67/// use rust_decimal_macros::dec;
68///
69/// let ts = Utc::now();
70/// let candles: Vec<Candle> = (0..100).map(|i| {
71/// let price = dec!(100) + rust_decimal::Decimal::from(i);
72/// Candle::new(price, price, price, price, dec!(1000), ts).unwrap()
73/// }).collect();
74/// let kf = KalmanFilter::new(dec!(0.01), dec!(0.001), dec!(1), 20, 20).unwrap();
75/// let result = kf.compute(&candles).unwrap();
76/// assert_eq!(result.len(), 80); // 100 - 20
77/// ```
78#[derive(Debug, Clone)]
79pub struct KalmanFilter {
80 q_level: Decimal,
81 q_slope: Decimal,
82 r_obs: Decimal,
83 warmup: usize,
84 eiv_window: usize,
85 name: String,
86}
87
88impl KalmanFilter {
89 /// Create a new Kalman filter.
90 ///
91 /// # Arguments
92 ///
93 /// * `q_level` — process noise for level state (must be > 0)
94 /// * `q_slope` — process noise for slope state (must be > 0)
95 /// * `r_obs` — measurement noise (must be > 0)
96 /// * `warmup` — bars consumed before output starts (must be ≥ 1)
97 /// * `eiv_window` — rolling window size for empirical innovation variance (must be ≥ 2)
98 ///
99 /// # Errors
100 ///
101 /// Returns `InvalidParameter` if any noise parameter is ≤ 0, warmup is 0, or
102 /// `eiv_window` is less than 2.
103 pub fn new(
104 q_level: Decimal,
105 q_slope: Decimal,
106 r_obs: Decimal,
107 warmup: usize,
108 eiv_window: usize,
109 ) -> Result<Self, IndicatorError> {
110 if q_level <= Decimal::ZERO {
111 return Err(IndicatorError::InvalidParameter {
112 message: "q_level must be > 0".to_string(),
113 });
114 }
115 if q_slope <= Decimal::ZERO {
116 return Err(IndicatorError::InvalidParameter {
117 message: "q_slope must be > 0".to_string(),
118 });
119 }
120 if r_obs <= Decimal::ZERO {
121 return Err(IndicatorError::InvalidParameter {
122 message: "r_obs must be > 0".to_string(),
123 });
124 }
125 if warmup == 0 {
126 return Err(IndicatorError::InvalidParameter {
127 message: "warmup must be >= 1".to_string(),
128 });
129 }
130 if eiv_window < 2 {
131 return Err(IndicatorError::InvalidParameter {
132 message: "eiv_window must be >= 2".to_string(),
133 });
134 }
135
136 Ok(Self {
137 q_level,
138 q_slope,
139 r_obs,
140 warmup,
141 eiv_window,
142 name: format!(
143 "Kalman({},{},{},{},{})",
144 q_level, q_slope, r_obs, warmup, eiv_window
145 ),
146 })
147 }
148
149 /// Get the indicator name.
150 #[must_use]
151 pub fn name(&self) -> &str {
152 &self.name
153 }
154
155 /// Minimum number of candles consumed before output starts.
156 #[must_use]
157 pub fn warmup_period(&self) -> usize {
158 self.warmup
159 }
160
161 /// Compute the Kalman filter on candle data.
162 ///
163 /// Returns `KalmanResult` with five output series, each of length
164 /// `candles.len() - warmup`.
165 ///
166 /// # Errors
167 ///
168 /// Returns `InsufficientData` if `candles.len() <= warmup`.
169 pub fn compute(&self, candles: &[Candle]) -> Result<KalmanResult, IndicatorError> {
170 if candles.len() <= self.warmup {
171 return Err(IndicatorError::InsufficientData {
172 required: self.warmup + 1,
173 actual: candles.len(),
174 });
175 }
176
177 let n = candles.len();
178 let out_len = n - self.warmup;
179
180 // State: [level, slope]
181 // Initialize from first candle
182 let mut level = candles[0].close();
183 let mut slope = Decimal::ZERO;
184
185 // 2x2 covariance matrix P = [[p00, p01], [p10, p11]]
186 // Initialize with large uncertainty
187 let mut p00 = Decimal::ONE;
188 let mut p01 = Decimal::ZERO;
189 let mut p10 = Decimal::ZERO;
190 let mut p11 = Decimal::ONE;
191
192 let mut levels = Vec::with_capacity(out_len);
193 let mut slopes = Vec::with_capacity(out_len);
194 let mut innov_vars = Vec::with_capacity(out_len);
195 let mut gains = Vec::with_capacity(out_len);
196 let mut norm_innovs = Vec::with_capacity(out_len);
197
198 // Ring buffer for rolling innovation variance over eiv_window bars.
199 let mut innov_ring: std::collections::VecDeque<Decimal> =
200 std::collections::VecDeque::with_capacity(self.eiv_window);
201 let mut empirical_var;
202
203 for (i, candle) in candles.iter().enumerate() {
204 let z = candle.close();
205
206 // --- Predict ---
207 // x_pred = F * x = [level + slope, slope]
208 let level_pred = level + slope;
209 let slope_pred = slope;
210
211 // --- Innovation (computed before P_pred to drive adaptive Q) ---
212 // y = z - H * x_pred = z - level_pred (H = [1, 0])
213 let innovation = z - level_pred;
214
215 // Rolling innovation variance over eiv_window bars.
216 // Maintains a ring buffer of raw innovations and computes
217 // variance = mean(v²) - mean(v)² over the window.
218 innov_ring.push_back(innovation);
219 if innov_ring.len() > self.eiv_window {
220 innov_ring.pop_front();
221 }
222 empirical_var = rolling_variance(&innov_ring);
223
224 // Adaptive Q: scale process noise by innovation surprise ratio
225 // (Kang 2026). When innovations are large relative to baseline R,
226 // process noise inflates → P_pred grows → gain increases.
227 let surprise = Decimal::ONE + decimal_div(empirical_var, self.r_obs);
228 let q_level_t = self.q_level * surprise;
229 let q_slope_t = self.q_slope * surprise;
230
231 // P_pred = F * P * F' + Q_t
232 // F = [[1,1],[0,1]], F' = [[1,0],[1,1]]
233 // F*P = [[p00+p10, p01+p11],[p10, p11]]
234 // (F*P)*F' = [[(p00+p10)+(p01+p11), p01+p11],[p10+p11, p11]]
235 let pp00 = p00 + p10 + p01 + p11 + q_level_t;
236 let pp01 = p01 + p11;
237 let pp10 = p10 + p11;
238 let pp11 = p11 + q_slope_t;
239
240 // S = H * P_pred * H' + R = pp00 + R
241 let s = pp00 + self.r_obs;
242
243 // --- Kalman gain ---
244 // K = P_pred * H' / S = [pp00/S, pp10/S]
245 let k0 = decimal_div(pp00, s);
246 let k1 = decimal_div(pp10, s);
247
248 // --- Update ---
249 // x = x_pred + K * y
250 level = level_pred + k0 * innovation;
251 slope = slope_pred + k1 * innovation;
252
253 // P = (I - K*H) * P_pred
254 // K*H = [[k0, 0],[k1, 0]]
255 // I - K*H = [[1-k0, 0],[-k1, 1]]
256 // (I - K*H) * P_pred:
257 p00 = (Decimal::ONE - k0) * pp00;
258 p01 = (Decimal::ONE - k0) * pp01;
259 p10 = pp10 - k1 * pp00;
260 p11 = pp11 - k1 * pp01;
261
262 // Emit output after warmup
263 if i >= self.warmup {
264 let ts = candle.timestamp();
265 levels.push((ts, level));
266 slopes.push((ts, slope));
267 innov_vars.push((ts, empirical_var));
268
269 // Clamp gain to [0, 1] for safety
270 let gain_clamped = clamp_decimal(k0, Decimal::ZERO, Decimal::ONE);
271 gains.push((ts, gain_clamped));
272
273 let norm_innov = normalized_innovation(innovation, empirical_var);
274 norm_innovs.push((ts, norm_innov));
275 }
276 }
277
278 Ok(KalmanResult {
279 level: Series::new(levels),
280 slope: Series::new(slopes),
281 innovation_variance: Series::new(innov_vars),
282 kalman_gain: Series::new(gains),
283 normalized_innovation: Series::new(norm_innovs),
284 })
285 }
286}
287
288/// Rolling variance of values in a `VecDeque`: `mean(v²) - mean(v)²`.
289///
290/// Returns `ZERO` when the buffer has fewer than 2 elements.
291fn rolling_variance(buf: &std::collections::VecDeque<Decimal>) -> Decimal {
292 if buf.len() < 2 {
293 return Decimal::ZERO;
294 }
295 let n = Decimal::from(buf.len() as u64);
296 let sum: Decimal = buf.iter().copied().sum();
297 let sum_sq: Decimal = buf.iter().map(|v| *v * *v).sum();
298 let mean = sum / n;
299 let var = sum_sq / n - mean * mean;
300 // Clamp to zero in case of floating-point drift
301 if var < Decimal::ZERO {
302 Decimal::ZERO
303 } else {
304 var
305 }
306}
307
308/// Compute normalized innovation: innovation / sqrt(empirical_var).
309///
310/// Returns ZERO if the variance is non-positive or its square root is zero.
311fn normalized_innovation(innovation: Decimal, empirical_var: Decimal) -> Decimal {
312 if empirical_var <= Decimal::ZERO {
313 return Decimal::ZERO;
314 }
315 let sqrt_v = decimal_sqrt(empirical_var);
316 if sqrt_v <= Decimal::ZERO {
317 return Decimal::ZERO;
318 }
319 decimal_div(innovation, sqrt_v)
320}
321
322/// Safe decimal division — returns ZERO on divide-by-zero.
323fn decimal_div(num: Decimal, den: Decimal) -> Decimal {
324 if den == Decimal::ZERO {
325 return Decimal::ZERO;
326 }
327 num / den
328}
329
330/// Clamp a Decimal to [lo, hi].
331fn clamp_decimal(val: Decimal, lo: Decimal, hi: Decimal) -> Decimal {
332 if val < lo {
333 lo
334 } else if val > hi {
335 hi
336 } else {
337 val
338 }
339}
340
341/// Newton-Raphson square root for Decimal.
342///
343/// Iterates until the estimate changes by less than `epsilon`.
344fn decimal_sqrt(value: Decimal) -> Decimal {
345 if value <= Decimal::ZERO {
346 return Decimal::ZERO;
347 }
348 if value == Decimal::ONE {
349 return Decimal::ONE;
350 }
351
352 // Initial guess: value / 2 (works for most ranges)
353 let mut guess = value / Decimal::TWO;
354 let epsilon = Decimal::new(1, 12); // 1e-12
355
356 for _ in 0..100 {
357 if guess <= Decimal::ZERO {
358 return Decimal::ZERO;
359 }
360 let next = (guess + value / guess) / Decimal::TWO;
361 let diff = if next > guess {
362 next - guess
363 } else {
364 guess - next
365 };
366 guess = next;
367 if diff < epsilon {
368 break;
369 }
370 }
371
372 guess
373}
374
// Unit tests live in the sibling file `kalman_tests.rs` and are compiled only for test builds.
#[cfg(test)]
#[path = "kalman_tests.rs"]
mod tests;