Skip to main content

quantwave_core/indicators/
kalman.rs

1use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
2use crate::traits::Next;
3
/// Kalman Filter
///
/// A one-dimensional Kalman filter used to smooth a price series.
///
/// The "true" price is modelled as a hidden state that stays constant between
/// samples (a constant-position model). Each call to `next` first grows the
/// estimate's uncertainty by the process noise `q`, then blends in the new
/// measurement using the Kalman gain `p_pred / (p_pred + r)`, where `r` is the
/// measurement noise.
///
/// The gain sequence depends only on `q`, `r` and the number of samples seen,
/// never on the prices themselves, and it settles to a constant. After a
/// warm-up the filter therefore behaves like an exponential moving average.
#[derive(Debug, Clone)]
pub struct KalmanFilter {
    /// Process noise variance, added to `p` on every prediction step.
    q: f64,
    /// Measurement noise variance, used in the gain denominator.
    r: f64,
    /// Current state (price) estimate; only meaningful once `initialized` is set.
    x: f64,
    /// Error covariance (uncertainty) of the current estimate.
    p: f64,
    /// Kalman gain applied on the most recent update.
    k: f64,
    /// Whether the first sample has been seen and used to seed `x`.
    initialized: bool,
}

impl KalmanFilter {
    /// Creates a filter with process noise `q` and measurement noise `r`.
    ///
    /// A larger `q` makes the filter follow price changes faster; a larger `r`
    /// smooths more. The state is seeded from the first sample passed to
    /// `next`, with an initial error covariance of `1.0`.
    ///
    /// Both noise terms are variances, so they are expected to be finite and
    /// non-negative. This is checked in debug builds only, so release behaviour
    /// is unchanged for existing callers.
    pub fn new(q: f64, r: f64) -> Self {
        debug_assert!(
            q.is_finite() && q >= 0.0,
            "KalmanFilter: process noise q must be finite and non-negative, got {}",
            q
        );
        debug_assert!(
            r.is_finite() && r >= 0.0,
            "KalmanFilter: measurement noise r must be finite and non-negative, got {}",
            r
        );
        Self {
            q,
            r,
            x: 0.0,
            p: 1.0,
            k: 0.0,
            initialized: false,
        }
    }

    /// Returns the current price estimate, or `None` before the first sample.
    #[must_use]
    pub fn estimate(&self) -> Option<f64> {
        self.initialized.then_some(self.x)
    }

    /// Returns the Kalman gain used on the most recent update.
    ///
    /// This is `0.0` until the second sample, because the first sample only
    /// seeds the state.
    #[must_use]
    pub fn gain(&self) -> f64 {
        self.k
    }

    /// Returns the error covariance (uncertainty) of the current estimate.
    #[must_use]
    pub fn error_covariance(&self) -> f64 {
        self.p
    }
}
31
32impl Next<f64> for KalmanFilter {
33    type Output = f64;
34
35    fn next(&mut self, input: f64) -> Self::Output {
36        if !self.initialized {
37            self.x = input;
38            self.p = 1.0;
39            self.initialized = true;
40            return self.x;
41        }
42
43        // Prediction
44        // x = x (constant position model)
45        // p = p + q
46        let p_pred = self.p + self.q;
47
48        // Update
49        // k = p_pred / (p_pred + r)
50        self.k = p_pred / (p_pred + self.r);
51        
52        // x = x + k * (input - x)
53        self.x = self.x + self.k * (input - self.x);
54        
55        // p = (1 - k) * p_pred
56        self.p = (1.0 - self.k) * p_pred;
57
58        self.x
59    }
60}
61
62pub const KALMAN_FILTER_METADATA: IndicatorMetadata = IndicatorMetadata {
63    name: "Kalman Filter",
64    description: "An adaptive 1D Kalman filter for smoothing price data with minimal lag.",
65    usage: "Use as a highly responsive alternative to moving averages. The Q parameter (process noise) controls responsiveness to trend changes, while R (measurement noise) controls smoothness. Higher Q makes it track price faster; higher R increases smoothing.",
66    keywords: &["filter", "adaptive", "smoothing", "ml", "kalman"],
67    ehlers_summary: "The Kalman Filter is an optimal estimator for linear systems with Gaussian noise. In technical analysis, the 1D version recursively updates the estimate of the 'true' price by balancing the predicted state against new measurements. It is particularly effective for feature engineering in ML models due to its ability to separate signal from noise dynamically.",
68    params: &[
69        ParamDef {
70            name: "q",
71            default: "0.01",
72            description: "Process noise (responsiveness)",
73        },
74        ParamDef {
75            name: "r",
76            default: "0.1",
77            description: "Measurement noise (smoothing)",
78        },
79    ],
80    formula_source: "https://en.wikipedia.org/wiki/Kalman_filter",
81    formula_latex: r#"
82\[
83P_{t|t-1} = P_{t-1} + Q
84\]
85\[
86K_t = \frac{P_{t|t-1}}{P_{t|t-1} + R}
87\]
88\[
89X_t = X_{t-1} + K_t(Z_t - X_{t-1})
90\]
91\[
92P_t = (1 - K_t)P_{t|t-1}
93\]
94"#,
95    gold_standard_file: "kalman_filter.json",
96    category: "ML Features",
97};
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Next;
    use proptest::prelude::*;

    /// Independent oracle: replays the scalar Kalman recursion over `inputs`,
    /// seeding the estimate with the first value and the covariance with 1.0.
    fn reference_kalman(inputs: &[f64], q: f64, r: f64) -> Vec<f64> {
        let mut out = Vec::with_capacity(inputs.len());
        if let Some((&first, rest)) = inputs.split_first() {
            let (mut estimate, mut covariance) = (first, 1.0_f64);
            out.push(estimate);
            for &measurement in rest {
                let predicted = covariance + q;
                let gain = predicted / (predicted + r);
                estimate += gain * (measurement - estimate);
                covariance = (1.0 - gain) * predicted;
                out.push(estimate);
            }
        }
        out
    }

    #[test]
    fn test_kalman_basic() {
        let mut filter = KalmanFilter::new(0.01, 0.1);
        // The first sample seeds the state and comes back unchanged.
        assert_eq!(filter.next(100.0), 100.0);
        // Later samples pull the estimate toward the measurement, but not all the way.
        let second = filter.next(101.0);
        assert!(
            second > 100.0 && second < 101.0,
            "expected 100 < {} < 101",
            second
        );
    }

    proptest! {
        #[test]
        fn test_kalman_parity(
            inputs in prop::collection::vec(1.0..100.0, 50..100),
        ) {
            let (q, r) = (0.01, 0.1);
            let mut filter = KalmanFilter::new(q, r);
            let streamed: Vec<f64> = inputs.iter().map(|&z| filter.next(z)).collect();
            let expected = reference_kalman(&inputs, q, r);

            for (got, want) in streamed.iter().zip(expected.iter()) {
                approx::assert_relative_eq!(got, want, epsilon = 1e-10);
            }
        }
    }
}