robotics/localization/extended_kalman_filter.rs

1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, Dim, OMatrix, OVector, RealField};
2use rustc_hash::FxHashMap;
3
4use crate::models::measurement::MeasurementModel;
5use crate::models::motion::MotionModel;
6use crate::utils::state::GaussianState;
7
8/// S : State Size, Z: Observation Size, U: Input Size
9pub struct ExtendedKalmanFilter<T: RealField, S: Dim, Z: Dim, U: Dim>
10where
11    DefaultAllocator: Allocator<T, S>
12        + Allocator<T, U>
13        + Allocator<T, Z>
14        + Allocator<T, S, S>
15        + Allocator<T, Z, Z>
16        + Allocator<T, Z, S>
17        + Allocator<T, S, U>
18        + Allocator<T, U, U>
19        + Allocator<T, S, Z>
20        + Allocator<T, Const<1>, S>
21        + Allocator<T, Const<1>, Z>,
22{
23    r: OMatrix<T, S, S>,
24    q: OMatrix<T, Z, Z>,
25    measurement_model: Box<dyn MeasurementModel<T, S, Z> + Send>,
26    motion_model: Box<dyn MotionModel<T, S, Z, U> + Send>,
27}
28
29impl<T: RealField, S: Dim, Z: Dim, U: Dim> ExtendedKalmanFilter<T, S, Z, U>
30where
31    DefaultAllocator: Allocator<T, S>
32        + Allocator<T, U>
33        + Allocator<T, Z>
34        + Allocator<T, S, S>
35        + Allocator<T, Z, Z>
36        + Allocator<T, Z, S>
37        + Allocator<T, S, U>
38        + Allocator<T, U, U>
39        + Allocator<T, S, Z>
40        + Allocator<T, Const<1>, S>
41        + Allocator<T, Const<1>, Z>,
42{
43    pub fn new(
44        r: OMatrix<T, S, S>,
45        q: OMatrix<T, Z, Z>,
46        measurement_model: Box<dyn MeasurementModel<T, S, Z> + Send>,
47        motion_model: Box<dyn MotionModel<T, S, Z, U> + Send>,
48    ) -> ExtendedKalmanFilter<T, S, Z, U> {
49        ExtendedKalmanFilter {
50            r,
51            q,
52            measurement_model,
53            motion_model,
54        }
55    }
56
57    pub fn estimate(
58        &self,
59        estimate: &GaussianState<T, S>,
60        u: &OVector<T, U>,
61        z: &OVector<T, Z>,
62        dt: T,
63    ) -> GaussianState<T, S> {
64        // predict
65        let g = self
66            .motion_model
67            .jacobian_wrt_state(&estimate.x, u, dt.clone());
68        let x_pred = self.motion_model.prediction(&estimate.x, u, dt);
69        let cov_pred = &g * &estimate.cov * g.transpose() + &self.r;
70
71        // update
72        let h = self.measurement_model.jacobian(&x_pred, None);
73        let z_pred = self.measurement_model.prediction(&x_pred, None);
74
75        let s = &h * &cov_pred * h.transpose() + &self.q;
76        let kalman_gain = &cov_pred * h.transpose() * s.try_inverse().unwrap();
77        let x_est = &x_pred + &kalman_gain * (z - z_pred);
78        let shape = cov_pred.shape_generic();
79        let cov_est = (OMatrix::identity_generic(shape.0, shape.1) - kalman_gain * h) * &cov_pred;
80        GaussianState {
81            x: x_est,
82            cov: cov_est,
83        }
84    }
85}
86
87/// S : State Size, Z: Observation Size, U: Input Size
88pub struct ExtendedKalmanFilterKnownCorrespondences<T: RealField, S: Dim, Z: Dim, U: Dim>
89where
90    DefaultAllocator: Allocator<T, S>
91        + Allocator<T, U>
92        + Allocator<T, Z>
93        + Allocator<T, S, S>
94        + Allocator<T, Z, Z>
95        + Allocator<T, Z, S>
96        + Allocator<T, S, U>
97        + Allocator<T, U, U>
98        + Allocator<T, S, Z>
99        + Allocator<T, Const<1>, S>
100        + Allocator<T, Const<1>, Z>
101        + Allocator<T, U, S>,
102{
103    r: OMatrix<T, S, S>,
104    q: OMatrix<T, Z, Z>,
105    landmarks: FxHashMap<u32, OVector<T, S>>,
106    measurement_model: Box<dyn MeasurementModel<T, S, Z> + Send>,
107    motion_model: Box<dyn MotionModel<T, S, Z, U> + Send>,
108    fixed_noise: bool,
109}
110
111impl<T: RealField, S: Dim, Z: Dim, U: Dim> ExtendedKalmanFilterKnownCorrespondences<T, S, Z, U>
112where
113    DefaultAllocator: Allocator<T, S>
114        + Allocator<T, U>
115        + Allocator<T, Z>
116        + Allocator<T, S, S>
117        + Allocator<T, Z, Z>
118        + Allocator<T, Z, S>
119        + Allocator<T, S, U>
120        + Allocator<T, U, U>
121        + Allocator<T, S, Z>
122        + Allocator<T, Const<1>, S>
123        + Allocator<T, Const<1>, Z>
124        + Allocator<T, U, S>,
125{
126    pub fn new(
127        r: OMatrix<T, S, S>,
128        q: OMatrix<T, Z, Z>,
129        landmarks: FxHashMap<u32, OVector<T, S>>,
130        measurement_model: Box<dyn MeasurementModel<T, S, Z> + Send>,
131        motion_model: Box<dyn MotionModel<T, S, Z, U> + Send>,
132        fixed_noise: bool,
133    ) -> ExtendedKalmanFilterKnownCorrespondences<T, S, Z, U> {
134        ExtendedKalmanFilterKnownCorrespondences {
135            q,
136            r,
137            landmarks,
138            measurement_model,
139            motion_model,
140            fixed_noise,
141        }
142    }
143
144    pub fn estimate(
145        &self,
146        estimate: &GaussianState<T, S>,
147        control: Option<OVector<T, U>>,
148        measurements: Option<Vec<(u32, OVector<T, Z>)>>,
149        dt: T,
150    ) -> GaussianState<T, S> {
151        let mut x_out = estimate.x.clone();
152        let mut cov_out = estimate.cov.clone();
153        // predict
154        if let Some(u) = control {
155            let g = self
156                .motion_model
157                .jacobian_wrt_state(&estimate.x, &u, dt.clone());
158
159            let x_est = self.motion_model.prediction(&estimate.x, &u, dt.clone());
160            let cov_est = if self.fixed_noise {
161                // fixed version
162                &g * &estimate.cov * g.transpose() + &self.r
163            } else {
164                // adaptive version
165                let v = self.motion_model.jacobian_wrt_input(&estimate.x, &u, dt);
166                let m = self.motion_model.cov_noise_control_space(&u);
167                &g * &estimate.cov * g.transpose() + &v * m * v.transpose()
168            };
169            x_out = x_est;
170            cov_out = cov_est;
171        }
172
173        // update / correction step
174        if let Some(measurements) = measurements {
175            let shape = cov_out.shape_generic();
176            for (id, z) in measurements
177                .iter()
178                .filter(|(id, _)| self.landmarks.contains_key(id))
179            {
180                let landmark = self.landmarks.get(id);
181                let z_pred = self.measurement_model.prediction(&x_out, landmark);
182                let h = self.measurement_model.jacobian(&x_out, landmark);
183                let s = &h * &cov_out * h.transpose() + &self.q;
184                let kalman_gain = &cov_out * h.transpose() * s.try_inverse().unwrap();
185                x_out += &kalman_gain * (z - z_pred);
186                cov_out = (OMatrix::identity_generic(shape.0, shape.1) - kalman_gain * h) * &cov_out
187            }
188        }
189
190        GaussianState {
191            x: x_out,
192            cov: cov_out,
193        }
194    }
195}
196
#[cfg(test)]
mod tests {
    use crate::localization::extended_kalman_filter::ExtendedKalmanFilter;
    use crate::models::measurement::SimpleProblemMeasurementModel;
    use crate::models::motion::SimpleProblemMotionModel;
    use crate::utils::deg2rad;
    use crate::utils::state::GaussianState;
    use nalgebra::{Const, Matrix2, Matrix4, Vector2, Vector4};

    /// Smoke test: one EKF cycle from the origin completes and yields a finite
    /// posterior mean and covariance.
    #[test]
    fn ekf_runs() {
        // setup ekf
        // `r` is the 4x4 process-noise covariance and `q` the 2x2
        // measurement-noise covariance, matching `ExtendedKalmanFilter::new(r, q, ..)`.
        let r = Matrix4::<f64>::from_diagonal(&Vector4::new(0.1, 0.1, deg2rad(1.0), 1.0));
        let q = Matrix2::identity();
        let motion_model = SimpleProblemMotionModel::new();
        let measurement_model = SimpleProblemMeasurementModel::new();
        let ekf = ExtendedKalmanFilter::<f64, Const<4>, Const<2>, Const<2>>::new(
            r,
            q,
            measurement_model,
            motion_model,
        );

        let dt = 0.1;
        let u: Vector2<f64> = Default::default();
        let kalman_state = GaussianState {
            x: Vector4::<f64>::new(0., 0., 0., 0.),
            cov: Matrix4::<f64>::identity(),
        };
        let z: Vector2<f64> = Default::default();

        let posterior = ekf.estimate(&kalman_state, &u, &z, dt);
        assert!(posterior.x.iter().all(|v| v.is_finite()));
        assert!(posterior.cov.iter().all(|v| v.is_finite()));
    }
}
230}