oxicuda_seq/kalman/
ekf.rs1use super::linalg::{add, eye, inverse, matmul, matmul_rect, sub, transpose_rect};
4use crate::error::{SeqError, SeqResult};
5
/// Per-step output of [`ExtendedKalmanFilter::run`].
///
/// Both vectors hold one entry per processed time step, in time order.
#[derive(Clone, Debug)]
pub struct ExtendedKalmanResult {
    /// State estimate recorded after each measurement update
    /// (`dim_x` values per entry).
    pub means: Vec<Vec<f64>>,
    /// Covariance recorded after each measurement update, stored as a flat
    /// `dim_x * dim_x` buffer per entry.
    pub covs: Vec<Vec<f64>>,
}
12
/// Extended Kalman filter described by user-supplied model closures.
///
/// Vectors and matrices are passed around as flat `f64` buffers; matrices use
/// the layout expected by the `linalg` helpers this module imports.
pub struct ExtendedKalmanFilter<'a> {
    /// Number of state components.
    pub dim_x: usize,
    /// Number of values in one measurement.
    pub dim_z: usize,
    /// State transition function: maps the current state estimate to the
    /// predicted state.
    pub f: Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>,
    /// Measurement function: maps a predicted state to the expected
    /// measurement.
    pub h: Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>,
    /// Jacobian of `f` as a flat `dim_x * dim_x` buffer; `run` evaluates it at
    /// the state estimate from the previous step.
    pub f_jacobian: Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>,
    /// Jacobian of `h` as a flat `dim_z * dim_x` buffer; `run` evaluates it at
    /// the predicted state.
    pub h_jacobian: Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>,
    /// Process noise covariance added to the predicted covariance
    /// (`dim_x * dim_x` values).
    pub q: Vec<f64>,
    /// Measurement noise covariance added to the innovation covariance
    /// (`dim_z * dim_z` values).
    pub r: Vec<f64>,
    /// Initial state estimate (`dim_x` values).
    pub x0: Vec<f64>,
    /// Initial state covariance (`dim_x * dim_x` values).
    pub p0: Vec<f64>,
}
27
28impl<'a> ExtendedKalmanFilter<'a> {
29 pub fn run(&self, z: &[f64]) -> SeqResult<ExtendedKalmanResult> {
31 if z.is_empty() {
32 return Err(SeqError::EmptyInput);
33 }
34 if z.len() % self.dim_z != 0 {
35 return Err(SeqError::DimensionMismatch {
36 a: z.len(),
37 b: self.dim_z,
38 });
39 }
40 let t_max = z.len() / self.dim_z;
41 let nx = self.dim_x;
42 let nz = self.dim_z;
43 let i_eye = eye(nx);
44 let mut means = Vec::with_capacity(t_max);
45 let mut covs = Vec::with_capacity(t_max);
46 let mut x = self.x0.clone();
47 let mut p = self.p0.clone();
48 for t in 0..t_max {
49 let x_pred = (self.f)(&x);
51 let f_j = (self.f_jacobian)(&x);
52 let f_j_t = transpose_rect(&f_j, nx, nx);
53 let fp = matmul(&f_j, &p, nx);
54 let fpft = matmul_rect(&fp, &f_j_t, nx, nx, nx);
55 let p_pred = add(&fpft, &self.q);
56
57 let z_t = &z[t * nz..(t + 1) * nz];
59 let h_x = (self.h)(&x_pred);
60 let mut y = vec![0.0; nz];
61 for k in 0..nz {
62 y[k] = z_t[k] - h_x[k];
63 }
64 let h_j = (self.h_jacobian)(&x_pred);
65 let h_j_t = transpose_rect(&h_j, nz, nx);
66 let hp = matmul_rect(&h_j, &p_pred, nz, nx, nx);
67 let hpht = matmul_rect(&hp, &h_j_t, nz, nx, nz);
68 let s = add(&hpht, &self.r);
69 let s_inv = inverse(&s, nz)?;
70 let pht = matmul_rect(&p_pred, &h_j_t, nx, nx, nz);
71 let k_gain = matmul_rect(&pht, &s_inv, nx, nz, nz);
72 let ky = matmul_rect(&k_gain, &y, nx, nz, 1);
73 for d in 0..nx {
74 x[d] = x_pred[d] + ky[d];
75 }
76 let kh = matmul_rect(&k_gain, &h_j, nx, nz, nx);
77 let i_kh = sub(&i_eye, &kh);
78 p = matmul_rect(&i_kh, &p_pred, nx, nx, nx);
79 means.push(x.clone());
80 covs.push(p.clone());
81 }
82 Ok(ExtendedKalmanResult { means, covs })
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar model with identity dynamics and identity measurement, for which
    /// the EKF linearisation is exact.
    fn scalar_ekf() -> ExtendedKalmanFilter<'static> {
        ExtendedKalmanFilter {
            dim_x: 1,
            dim_z: 1,
            f: Box::new(|x: &[f64]| vec![x[0]]),
            h: Box::new(|x: &[f64]| vec![x[0]]),
            f_jacobian: Box::new(|_x: &[f64]| vec![1.0]),
            h_jacobian: Box::new(|_x: &[f64]| vec![1.0]),
            q: vec![0.01],
            r: vec![0.05],
            x0: vec![0.0],
            p0: vec![1.0],
        }
    }

    #[test]
    fn ekf_linear_matches_kf() {
        let ekf = scalar_ekf();
        let z = [1.0, 1.05, 0.95];
        let res = ekf.run(&z).expect("ok");
        assert_eq!(res.means.len(), z.len());
        assert_eq!(res.covs.len(), z.len());

        // Reference: the closed-form scalar Kalman recursion for the same
        // model. With linear `f` and `h` the EKF must reproduce it.
        let (q, r) = (0.01_f64, 0.05_f64);
        let (mut x, mut p) = (0.0_f64, 1.0_f64);
        for (t, &z_t) in z.iter().enumerate() {
            let p_pred = p + q;
            let gain = p_pred / (p_pred + r);
            x += gain * (z_t - x);
            p = (1.0 - gain) * p_pred;
            assert!((res.means[t][0] - x).abs() < 1e-9, "mean at step {t}");
            assert!((res.covs[t][0] - p).abs() < 1e-9, "covariance at step {t}");
        }

        // The estimate should have moved from 0 towards the measurements.
        assert!((res.means[res.means.len() - 1][0] - 1.0).abs() < 0.2);
    }

    #[test]
    fn ekf_rejects_empty_measurements() {
        assert!(matches!(scalar_ekf().run(&[]), Err(SeqError::EmptyInput)));
    }

    #[test]
    fn ekf_rejects_ragged_measurements() {
        // Three values cannot be split into whole two-value measurements.
        let ekf = ExtendedKalmanFilter {
            dim_z: 2,
            ..scalar_ekf()
        };
        assert!(matches!(
            ekf.run(&[1.0, 2.0, 3.0]),
            Err(SeqError::DimensionMismatch { .. })
        ));
    }
}