1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, Dim, OMatrix, OVector, RealField};
2use rustc_hash::FxHashMap;
3
4use crate::models::measurement::MeasurementModel;
5use crate::models::motion::MotionModel;
6use crate::utils::state::GaussianState;
7
8pub struct ExtendedKalmanFilter<T: RealField, S: Dim, Z: Dim, U: Dim>
10where
11 DefaultAllocator: Allocator<T, S>
12 + Allocator<T, U>
13 + Allocator<T, Z>
14 + Allocator<T, S, S>
15 + Allocator<T, Z, Z>
16 + Allocator<T, Z, S>
17 + Allocator<T, S, U>
18 + Allocator<T, U, U>
19 + Allocator<T, S, Z>
20 + Allocator<T, Const<1>, S>
21 + Allocator<T, Const<1>, Z>,
22{
23 r: OMatrix<T, S, S>,
24 q: OMatrix<T, Z, Z>,
25 measurement_model: Box<dyn MeasurementModel<T, S, Z> + Send>,
26 motion_model: Box<dyn MotionModel<T, S, Z, U> + Send>,
27}
28
29impl<T: RealField, S: Dim, Z: Dim, U: Dim> ExtendedKalmanFilter<T, S, Z, U>
30where
31 DefaultAllocator: Allocator<T, S>
32 + Allocator<T, U>
33 + Allocator<T, Z>
34 + Allocator<T, S, S>
35 + Allocator<T, Z, Z>
36 + Allocator<T, Z, S>
37 + Allocator<T, S, U>
38 + Allocator<T, U, U>
39 + Allocator<T, S, Z>
40 + Allocator<T, Const<1>, S>
41 + Allocator<T, Const<1>, Z>,
42{
43 pub fn new(
44 r: OMatrix<T, S, S>,
45 q: OMatrix<T, Z, Z>,
46 measurement_model: Box<dyn MeasurementModel<T, S, Z> + Send>,
47 motion_model: Box<dyn MotionModel<T, S, Z, U> + Send>,
48 ) -> ExtendedKalmanFilter<T, S, Z, U> {
49 ExtendedKalmanFilter {
50 r,
51 q,
52 measurement_model,
53 motion_model,
54 }
55 }
56
57 pub fn estimate(
58 &self,
59 estimate: &GaussianState<T, S>,
60 u: &OVector<T, U>,
61 z: &OVector<T, Z>,
62 dt: T,
63 ) -> GaussianState<T, S> {
64 let g = self
66 .motion_model
67 .jacobian_wrt_state(&estimate.x, u, dt.clone());
68 let x_pred = self.motion_model.prediction(&estimate.x, u, dt);
69 let cov_pred = &g * &estimate.cov * g.transpose() + &self.r;
70
71 let h = self.measurement_model.jacobian(&x_pred, None);
73 let z_pred = self.measurement_model.prediction(&x_pred, None);
74
75 let s = &h * &cov_pred * h.transpose() + &self.q;
76 let kalman_gain = &cov_pred * h.transpose() * s.try_inverse().unwrap();
77 let x_est = &x_pred + &kalman_gain * (z - z_pred);
78 let shape = cov_pred.shape_generic();
79 let cov_est = (OMatrix::identity_generic(shape.0, shape.1) - kalman_gain * h) * &cov_pred;
80 GaussianState {
81 x: x_est,
82 cov: cov_est,
83 }
84 }
85}
86
87pub struct ExtendedKalmanFilterKnownCorrespondences<T: RealField, S: Dim, Z: Dim, U: Dim>
89where
90 DefaultAllocator: Allocator<T, S>
91 + Allocator<T, U>
92 + Allocator<T, Z>
93 + Allocator<T, S, S>
94 + Allocator<T, Z, Z>
95 + Allocator<T, Z, S>
96 + Allocator<T, S, U>
97 + Allocator<T, U, U>
98 + Allocator<T, S, Z>
99 + Allocator<T, Const<1>, S>
100 + Allocator<T, Const<1>, Z>
101 + Allocator<T, U, S>,
102{
103 r: OMatrix<T, S, S>,
104 q: OMatrix<T, Z, Z>,
105 landmarks: FxHashMap<u32, OVector<T, S>>,
106 measurement_model: Box<dyn MeasurementModel<T, S, Z> + Send>,
107 motion_model: Box<dyn MotionModel<T, S, Z, U> + Send>,
108 fixed_noise: bool,
109}
110
111impl<T: RealField, S: Dim, Z: Dim, U: Dim> ExtendedKalmanFilterKnownCorrespondences<T, S, Z, U>
112where
113 DefaultAllocator: Allocator<T, S>
114 + Allocator<T, U>
115 + Allocator<T, Z>
116 + Allocator<T, S, S>
117 + Allocator<T, Z, Z>
118 + Allocator<T, Z, S>
119 + Allocator<T, S, U>
120 + Allocator<T, U, U>
121 + Allocator<T, S, Z>
122 + Allocator<T, Const<1>, S>
123 + Allocator<T, Const<1>, Z>
124 + Allocator<T, U, S>,
125{
126 pub fn new(
127 r: OMatrix<T, S, S>,
128 q: OMatrix<T, Z, Z>,
129 landmarks: FxHashMap<u32, OVector<T, S>>,
130 measurement_model: Box<dyn MeasurementModel<T, S, Z> + Send>,
131 motion_model: Box<dyn MotionModel<T, S, Z, U> + Send>,
132 fixed_noise: bool,
133 ) -> ExtendedKalmanFilterKnownCorrespondences<T, S, Z, U> {
134 ExtendedKalmanFilterKnownCorrespondences {
135 q,
136 r,
137 landmarks,
138 measurement_model,
139 motion_model,
140 fixed_noise,
141 }
142 }
143
144 pub fn estimate(
145 &self,
146 estimate: &GaussianState<T, S>,
147 control: Option<OVector<T, U>>,
148 measurements: Option<Vec<(u32, OVector<T, Z>)>>,
149 dt: T,
150 ) -> GaussianState<T, S> {
151 let mut x_out = estimate.x.clone();
152 let mut cov_out = estimate.cov.clone();
153 if let Some(u) = control {
155 let g = self
156 .motion_model
157 .jacobian_wrt_state(&estimate.x, &u, dt.clone());
158
159 let x_est = self.motion_model.prediction(&estimate.x, &u, dt.clone());
160 let cov_est = if self.fixed_noise {
161 &g * &estimate.cov * g.transpose() + &self.r
163 } else {
164 let v = self.motion_model.jacobian_wrt_input(&estimate.x, &u, dt);
166 let m = self.motion_model.cov_noise_control_space(&u);
167 &g * &estimate.cov * g.transpose() + &v * m * v.transpose()
168 };
169 x_out = x_est;
170 cov_out = cov_est;
171 }
172
173 if let Some(measurements) = measurements {
175 let shape = cov_out.shape_generic();
176 for (id, z) in measurements
177 .iter()
178 .filter(|(id, _)| self.landmarks.contains_key(id))
179 {
180 let landmark = self.landmarks.get(id);
181 let z_pred = self.measurement_model.prediction(&x_out, landmark);
182 let h = self.measurement_model.jacobian(&x_out, landmark);
183 let s = &h * &cov_out * h.transpose() + &self.q;
184 let kalman_gain = &cov_out * h.transpose() * s.try_inverse().unwrap();
185 x_out += &kalman_gain * (z - z_pred);
186 cov_out = (OMatrix::identity_generic(shape.0, shape.1) - kalman_gain * h) * &cov_out
187 }
188 }
189
190 GaussianState {
191 x: x_out,
192 cov: cov_out,
193 }
194 }
195}
196
#[cfg(test)]
mod tests {
    use crate::localization::extended_kalman_filter::ExtendedKalmanFilter;
    use crate::models::measurement::SimpleProblemMeasurementModel;
    use crate::models::motion::SimpleProblemMotionModel;
    use crate::utils::deg2rad;
    use crate::utils::state::GaussianState;
    use nalgebra::{Const, Matrix2, Matrix4, Vector2, Vector4};

    /// Smoke test: one predict/correct cycle on the simple 4-state,
    /// 2-measurement, 2-input problem completes without panicking.
    #[test]
    fn ekf_runs() {
        // 4x4 covariance for the filter's `r` slot (added during prediction).
        let r = Matrix4::<f64>::from_diagonal(&Vector4::new(0.1, 0.1, deg2rad(1.0), 1.0));
        // 2x2 covariance for the filter's `q` slot (added to the innovation).
        let q = Matrix2::<f64>::identity();
        let motion = SimpleProblemMotionModel::new();
        let measurement = SimpleProblemMeasurementModel::new();
        let ekf = ExtendedKalmanFilter::<f64, Const<4>, Const<2>, Const<2>>::new(
            r,
            q,
            measurement,
            motion,
        );

        let prior = GaussianState {
            x: Vector4::<f64>::zeros(),
            cov: Matrix4::<f64>::identity(),
        };
        let control = Vector2::<f64>::zeros();
        let observation = Vector2::<f64>::zeros();

        ekf.estimate(&prior, &control, &observation, 0.1);
    }
}