Skip to main content

oxicuda_seq/kalman/
ekf.rs

1//! Extended Kalman Filter with closure-provided non-linear dynamics and Jacobians.
2
3use super::linalg::{add, eye, inverse, matmul, matmul_rect, sub, transpose_rect};
4use crate::error::{SeqError, SeqResult};
5
/// Filtered estimates produced by running the extended Kalman filter on a sequence.
///
/// Both vectors hold exactly one entry per observation time step, in input order.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtendedKalmanResult {
    /// Posterior state mean after each update step (`dim_x` values per entry).
    pub means: Vec<Vec<f64>>,
    /// Posterior state covariance after each update step: a flattened
    /// `dim_x × dim_x` matrix per entry.
    pub covs: Vec<Vec<f64>>,
}
12
/// Boxed vector-valued map shared by the model functions and their Jacobians.
type VectorFn<'a> = Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>;

/// Extended Kalman filter driven by user-supplied closures: the non-linear
/// dynamics `f(x)`, the observation model `h(x)`, and their Jacobians
/// `Fj(x)` and `Hj(x)`.
///
/// Vectors and matrices are plain flat `Vec<f64>` buffers. Matrix buffers are
/// presumably row-major, matching the `linalg` helpers used by `run` (confirm
/// against `linalg`).
pub struct ExtendedKalmanFilter<'a> {
    /// Number of state components.
    pub dim_x: usize,
    /// Number of components in one observation.
    pub dim_z: usize,
    /// State transition `x ↦ f(x)`; expected to return `dim_x` values.
    pub f: VectorFn<'a>,
    /// Observation model `x ↦ h(x)`; expected to return `dim_z` values.
    pub h: VectorFn<'a>,
    /// Jacobian of `f` at `x`, as a flattened `dim_x × dim_x` matrix.
    pub f_jacobian: VectorFn<'a>,
    /// Jacobian of `h` at `x`, as a flattened `dim_z × dim_x` matrix.
    pub h_jacobian: VectorFn<'a>,
    /// Process-noise covariance (flattened `dim_x × dim_x`).
    pub q: Vec<f64>,
    /// Measurement-noise covariance (flattened `dim_z × dim_z`).
    pub r: Vec<f64>,
    /// Initial state mean (`dim_x` values).
    pub x0: Vec<f64>,
    /// Initial state covariance (flattened `dim_x × dim_x`).
    pub p0: Vec<f64>,
}
27
28impl<'a> ExtendedKalmanFilter<'a> {
29    /// Run the EKF on `z` (T × dim_z).
30    pub fn run(&self, z: &[f64]) -> SeqResult<ExtendedKalmanResult> {
31        if z.is_empty() {
32            return Err(SeqError::EmptyInput);
33        }
34        if z.len() % self.dim_z != 0 {
35            return Err(SeqError::DimensionMismatch {
36                a: z.len(),
37                b: self.dim_z,
38            });
39        }
40        let t_max = z.len() / self.dim_z;
41        let nx = self.dim_x;
42        let nz = self.dim_z;
43        let i_eye = eye(nx);
44        let mut means = Vec::with_capacity(t_max);
45        let mut covs = Vec::with_capacity(t_max);
46        let mut x = self.x0.clone();
47        let mut p = self.p0.clone();
48        for t in 0..t_max {
49            // Predict
50            let x_pred = (self.f)(&x);
51            let f_j = (self.f_jacobian)(&x);
52            let f_j_t = transpose_rect(&f_j, nx, nx);
53            let fp = matmul(&f_j, &p, nx);
54            let fpft = matmul_rect(&fp, &f_j_t, nx, nx, nx);
55            let p_pred = add(&fpft, &self.q);
56
57            // Update
58            let z_t = &z[t * nz..(t + 1) * nz];
59            let h_x = (self.h)(&x_pred);
60            let mut y = vec![0.0; nz];
61            for k in 0..nz {
62                y[k] = z_t[k] - h_x[k];
63            }
64            let h_j = (self.h_jacobian)(&x_pred);
65            let h_j_t = transpose_rect(&h_j, nz, nx);
66            let hp = matmul_rect(&h_j, &p_pred, nz, nx, nx);
67            let hpht = matmul_rect(&hp, &h_j_t, nz, nx, nz);
68            let s = add(&hpht, &self.r);
69            let s_inv = inverse(&s, nz)?;
70            let pht = matmul_rect(&p_pred, &h_j_t, nx, nx, nz);
71            let k_gain = matmul_rect(&pht, &s_inv, nx, nz, nz);
72            let ky = matmul_rect(&k_gain, &y, nx, nz, 1);
73            for d in 0..nx {
74                x[d] = x_pred[d] + ky[d];
75            }
76            let kh = matmul_rect(&k_gain, &h_j, nx, nz, nx);
77            let i_kh = sub(&i_eye, &kh);
78            p = matmul_rect(&i_kh, &p_pred, nx, nx, nx);
79            means.push(x.clone());
80            covs.push(p.clone());
81        }
82        Ok(ExtendedKalmanResult { means, covs })
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar random-walk model: `f(x) = x`, `h(x) = x`, unit Jacobians.
    fn random_walk(q: f64, r: f64, x0: f64, p0: f64) -> ExtendedKalmanFilter<'static> {
        ExtendedKalmanFilter {
            dim_x: 1,
            dim_z: 1,
            f: Box::new(|x: &[f64]| vec![x[0]]),
            h: Box::new(|x: &[f64]| vec![x[0]]),
            f_jacobian: Box::new(|_x: &[f64]| vec![1.0]),
            h_jacobian: Box::new(|_x: &[f64]| vec![1.0]),
            q: vec![q],
            r: vec![r],
            x0: vec![x0],
            p0: vec![p0],
        }
    }

    #[test]
    fn ekf_linear_matches_kf() {
        let (q, r, x0, p0) = (0.01, 0.05, 0.0, 1.0);
        let z = vec![1.0, 1.05, 0.95];
        let res = random_walk(q, r, x0, p0).run(&z).expect("ok");
        assert_eq!(res.means.len(), z.len());
        assert_eq!(res.covs.len(), z.len());

        // With F = H = 1 the EKF must reproduce the closed-form scalar Kalman filter.
        let (mut x, mut p) = (x0, p0);
        for (t, &z_t) in z.iter().enumerate() {
            let p_pred = p + q;
            let gain = p_pred / (p_pred + r);
            x += gain * (z_t - x);
            p = (1.0 - gain) * p_pred;
            assert!(
                (res.means[t][0] - x).abs() < 1e-10,
                "mean mismatch at t={}",
                t
            );
            assert!(
                (res.covs[t][0] - p).abs() < 1e-10,
                "covariance mismatch at t={}",
                t
            );
        }
        assert!((res.means[res.means.len() - 1][0] - 1.0).abs() < 0.2);
    }

    #[test]
    fn ekf_rejects_empty_input() {
        let res = random_walk(0.01, 0.05, 0.0, 1.0).run(&[]);
        assert!(matches!(res, Err(SeqError::EmptyInput)));
    }

    #[test]
    fn ekf_rejects_ragged_observations() {
        let mut ekf = random_walk(0.01, 0.05, 0.0, 1.0);
        ekf.dim_z = 2;
        let res = ekf.run(&[1.0, 2.0, 3.0]);
        assert!(matches!(
            res,
            Err(SeqError::DimensionMismatch { a: 3, b: 2 })
        ));
    }
}