1use nalgebra;
2use std::error::Error;
3use std::fmt;
4
/// Error returned by the fallible operations of [`KalmanBBox`].
#[derive(Debug)]
pub struct KalmanBBoxError {
    /// Numeric error code: `1` is reported as a failed matrix inversion,
    /// every other value is reported as an undefined error.
    typ: u16,
}

impl fmt::Display for KalmanBBoxError {
    /// Writes the human-readable message associated with the error code.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = if self.typ == 1 {
            "Can't inverse matrix"
        } else {
            "Undefined error"
        };
        f.write_str(message)
    }
}

impl Error for KalmanBBoxError {}
19
20#[derive(Debug, Clone)]
23pub struct KalmanBBox {
24 dt: f32,
26 u: nalgebra::SMatrix<f32, 4, 1>,
28 std_dev_a: f32,
30 std_dev_m_cx: f32,
32 std_dev_m_cy: f32,
34 std_dev_m_w: f32,
36 std_dev_m_h: f32,
38 A: nalgebra::SMatrix<f32, 8, 8>,
40 B: nalgebra::SMatrix<f32, 8, 4>,
42 H: nalgebra::SMatrix<f32, 4, 8>,
44 Q: nalgebra::SMatrix<f32, 8, 8>,
46 R: nalgebra::SMatrix<f32, 4, 4>,
48 P: nalgebra::SMatrix<f32, 8, 8>,
50 x: nalgebra::SVector<f32, 8>,
52}
53
54impl KalmanBBox {
55 pub fn new(
74 dt: f32,
75 u_cx: f32,
76 u_cy: f32,
77 u_w: f32,
78 u_h: f32,
79 std_dev_a: f32,
80 std_dev_m_cx: f32,
81 std_dev_m_cy: f32,
82 std_dev_m_w: f32,
83 std_dev_m_h: f32,
84 ) -> Self {
85 let dt2 = dt * dt;
86 let dt3 = dt2 * dt;
87 let dt4 = dt3 * dt;
88 let std_dev_a2 = std_dev_a * std_dev_a;
89
90 KalmanBBox {
91 dt,
92 u: nalgebra::SMatrix::<f32, 4, 1>::new(u_cx, u_cy, u_w, u_h),
93 std_dev_a,
94 std_dev_m_cx,
95 std_dev_m_cy,
96 std_dev_m_w,
97 std_dev_m_h,
98 #[rustfmt::skip]
100 A: nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
101 1.0, 0.0, 0.0, 0.0, dt, 0.0, 0.0, 0.0,
102 0.0, 1.0, 0.0, 0.0, 0.0, dt, 0.0, 0.0,
103 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, dt, 0.0,
104 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, dt,
105 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
106 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
107 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
108 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
109 ]),
110 #[rustfmt::skip]
112 B: nalgebra::SMatrix::<f32, 8, 4>::from_row_slice(&[
113 0.5 * dt2, 0.0, 0.0, 0.0,
114 0.0, 0.5 * dt2, 0.0, 0.0,
115 0.0, 0.0, 0.5 * dt2, 0.0,
116 0.0, 0.0, 0.0, 0.5 * dt2,
117 dt, 0.0, 0.0, 0.0,
118 0.0, dt, 0.0, 0.0,
119 0.0, 0.0, dt, 0.0,
120 0.0, 0.0, 0.0, dt,
121 ]),
122 #[rustfmt::skip]
124 H: nalgebra::SMatrix::<f32, 4, 8>::from_row_slice(&[
125 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
126 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
127 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
128 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
129 ]),
130 #[rustfmt::skip]
132 Q: nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
133 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0,
134 0.0, 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0,
135 0.0, 0.0, 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0,
136 0.0, 0.0, 0.0, 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2,
137 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2, 0.0, 0.0, 0.0,
138 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2, 0.0, 0.0,
139 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2, 0.0,
140 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2,
141 ]),
142 R: nalgebra::SMatrix::<f32, 4, 4>::new(
144 std_dev_m_cx * std_dev_m_cx, 0.0, 0.0, 0.0,
145 0.0, std_dev_m_cy * std_dev_m_cy, 0.0, 0.0,
146 0.0, 0.0, std_dev_m_w * std_dev_m_w, 0.0,
147 0.0, 0.0, 0.0, std_dev_m_h * std_dev_m_h,
148 ),
149 #[rustfmt::skip]
150 P: nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
151 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
152 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
153 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
154 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
155 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
156 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
157 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
158 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
159 ]),
160 #[rustfmt::skip]
161 x: nalgebra::SVector::<f32, 8>::from_row_slice(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
162 }
163 }
164
165 pub fn new_with_state(
194 dt: f32,
195 u_cx: f32,
196 u_cy: f32,
197 u_w: f32,
198 u_h: f32,
199 std_dev_a: f32,
200 std_dev_m_cx: f32,
201 std_dev_m_cy: f32,
202 std_dev_m_w: f32,
203 std_dev_m_h: f32,
204 cx: f32,
205 cy: f32,
206 w: f32,
207 h: f32,
208 ) -> Self {
209 let dt2 = dt * dt;
210 let dt3 = dt2 * dt;
211 let dt4 = dt3 * dt;
212 let std_dev_a2 = std_dev_a * std_dev_a;
213
214 KalmanBBox {
215 dt,
216 u: nalgebra::SMatrix::<f32, 4, 1>::new(u_cx, u_cy, u_w, u_h),
217 std_dev_a,
218 std_dev_m_cx,
219 std_dev_m_cy,
220 std_dev_m_w,
221 std_dev_m_h,
222 #[rustfmt::skip]
224 A: nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
225 1.0, 0.0, 0.0, 0.0, dt, 0.0, 0.0, 0.0,
226 0.0, 1.0, 0.0, 0.0, 0.0, dt, 0.0, 0.0,
227 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, dt, 0.0,
228 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, dt,
229 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
230 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
231 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
232 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
233 ]),
234 #[rustfmt::skip]
236 B: nalgebra::SMatrix::<f32, 8, 4>::from_row_slice(&[
237 0.5 * dt2, 0.0, 0.0, 0.0,
238 0.0, 0.5 * dt2, 0.0, 0.0,
239 0.0, 0.0, 0.5 * dt2, 0.0,
240 0.0, 0.0, 0.0, 0.5 * dt2,
241 dt, 0.0, 0.0, 0.0,
242 0.0, dt, 0.0, 0.0,
243 0.0, 0.0, dt, 0.0,
244 0.0, 0.0, 0.0, dt,
245 ]),
246 #[rustfmt::skip]
248 H: nalgebra::SMatrix::<f32, 4, 8>::from_row_slice(&[
249 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
250 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
251 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
252 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
253 ]),
254 #[rustfmt::skip]
256 Q: nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
257 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0,
258 0.0, 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0,
259 0.0, 0.0, 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0,
260 0.0, 0.0, 0.0, 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2,
261 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2, 0.0, 0.0, 0.0,
262 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2, 0.0, 0.0,
263 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2, 0.0,
264 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2,
265 ]),
266 R: nalgebra::SMatrix::<f32, 4, 4>::new(
268 std_dev_m_cx * std_dev_m_cx, 0.0, 0.0, 0.0,
269 0.0, std_dev_m_cy * std_dev_m_cy, 0.0, 0.0,
270 0.0, 0.0, std_dev_m_w * std_dev_m_w, 0.0,
271 0.0, 0.0, 0.0, std_dev_m_h * std_dev_m_h,
272 ),
273 #[rustfmt::skip]
274 P: nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
275 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
276 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
277 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
278 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
279 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
280 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
281 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
282 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
283 ]),
284 #[rustfmt::skip]
285 x: nalgebra::SVector::<f32, 8>::from_row_slice(&[cx, cy, w, h, 0.0, 0.0, 0.0, 0.0]),
286 }
287 }
288
289 pub fn predict(&mut self) {
306 self.x = (self.A * self.x) + (self.B * self.u);
308 self.P = self.A * self.P * self.A.transpose() + self.Q;
310 }
311
312 pub fn update(&mut self, _z_cx: f32, _z_cy: f32, _z_w: f32, _z_h: f32) -> Result<(), KalmanBBoxError> {
328 let S = self.H * self.P * self.H.transpose() + self.R;
330 let gain = match S.try_inverse() {
331 Some(inv) => self.P * self.H.transpose() * inv,
332 None => return Err(KalmanBBoxError { typ: 1 }),
333 };
334 let z = nalgebra::SMatrix::<f32, 4, 1>::new(_z_cx, _z_cy, _z_w, _z_h);
336 let r = z - self.H * self.x;
337 self.x = self.x + gain * r;
339 #[rustfmt::skip]
342 let I: nalgebra::SMatrix<f32, 8, 8> = nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
343 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
344 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
345 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
346 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
347 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
348 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
349 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
350 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
351 ]);
352 self.P = (I - gain * self.H) * self.P;
353 Ok(())
354 }
355
356 pub fn get_state(&self) -> (f32, f32, f32, f32) {
358 (self.x[0], self.x[1], self.x[2], self.x[3])
359 }
360
361 pub fn get_velocity(&self) -> (f32, f32, f32, f32) {
363 (self.x[4], self.x[5], self.x[6], self.x[7])
364 }
365
366 pub fn get_predicted_state(&self) -> (f32, f32, f32, f32) {
368 let x_pred = (self.A * self.x) + (self.B * self.u);
369 (x_pred[0], x_pred[1], x_pred[2], x_pred[3])
370 }
371
372 pub fn get_position_uncertainty(&self) -> f32 {
374 (self.P[(0, 0)].powi(2) + self.P[(1, 1)].powi(2)).sqrt()
375 }
376
377 pub fn get_vector_state(&self) -> nalgebra::SVector<f32, 8> {
379 self.x
380 }
381
382 pub fn mahalanobis_distance_squared(&self, z_cx: f32, z_cy: f32, z_w: f32, z_h: f32) -> Result<f32, KalmanBBoxError> {
389 let z = nalgebra::SMatrix::<f32, 4, 1>::new(z_cx, z_cy, z_w, z_h);
391 let y = z - self.H * self.x;
392
393 let S = self.H * self.P * self.H.transpose() + self.R;
395
396 let S_inv = match S.try_inverse() {
398 Some(inv) => inv,
399 None => return Err(KalmanBBoxError { typ: 1 }),
400 };
401
402 let d_squared = (y.transpose() * S_inv * y)[(0, 0)];
403 Ok(d_squared)
404 }
405
406 pub fn mahalanobis_distance(&self, z_cx: f32, z_cy: f32, z_w: f32, z_h: f32) -> Result<f32, KalmanBBoxError> {
409 let d_squared = self.mahalanobis_distance_squared(z_cx, z_cy, z_w, z_h)?;
410 Ok(d_squared.sqrt())
411 }
412
413 pub fn get_innovation_covariance(&self) -> nalgebra::SMatrix<f32, 4, 4> {
418 self.H * self.P * self.H.transpose() + self.R
419 }
420
421 pub fn gating_check(&self, z_cx: f32, z_cy: f32, z_w: f32, z_h: f32, threshold: f32) -> bool {
428 match self.mahalanobis_distance_squared(z_cx, z_cy, z_w, z_h) {
429 Ok(d_sq) => d_sq < threshold,
430 Err(_) => false,
431 }
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs predict/update over a recorded track and checks that the final
    /// estimate stays within 5 units of the last measurement on every axis.
    #[test]
    fn test_bbox_kalman() {
        let dt = 0.04;
        let (u_cx, u_cy) = (1.0, 1.0);
        let (u_w, u_h) = (0.0, 0.0);
        let std_dev_a = 2.0;
        let std_dev_m_cx = 0.1;
        let std_dev_m_cy = 0.1;
        let std_dev_m_w = 0.1;
        let std_dev_m_h = 0.1;

        #[rustfmt::skip]
        let xs = [
            311, 312, 313, 311, 311, 312, 312, 313, 312, 312, 312, 312, 312, 312, 312, 312, 312,
            312, 311, 311, 311, 311, 311, 310, 311, 311, 311, 310, 310, 308, 307, 308, 308, 308,
            307, 307, 307, 308, 307, 307, 307, 307, 307, 308, 307, 309, 306, 307, 306, 307, 308,
            306, 306, 306, 305, 307, 307, 307, 306, 306, 306, 307, 307, 308, 307, 307, 308, 307,
            306, 308, 309, 309, 309, 309, 308, 309, 309, 309, 308, 311, 311, 307, 311, 307, 313,
            311, 307, 311, 311, 306, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312,
            312, 312, 312, 312, 312, 312, 312, 312, 312, 312,
        ];
        #[rustfmt::skip]
        let ys = [
            5, 6, 8, 10, 11, 12, 12, 13, 16, 16, 18, 18, 19, 19, 20, 20, 22, 22, 23, 23, 24, 24,
            28, 30, 32, 35, 39, 42, 44, 46, 56, 58, 70, 60, 52, 64, 51, 70, 70, 70, 66, 83, 80, 85,
            80, 98, 79, 98, 61, 94, 101, 94, 104, 94, 107, 112, 108, 108, 109, 109, 121, 108, 108,
            120, 122, 122, 128, 130, 122, 140, 122, 122, 140, 122, 134, 141, 136, 136, 154, 155,
            155, 150, 161, 162, 169, 171, 181, 175, 175, 163, 178, 178, 178, 178, 178, 178, 178,
            178, 178, 178, 178, 178, 178, 178, 178, 178, 178, 178, 178, 178, 178, 178,
        ];

        // Synthetic width/height measurements: a sinusoid around a base size,
        // one sample per recorded centre point.
        let base_w = 40.0_f32;
        let base_h = 80.0_f32;
        let ws: Vec<f32> = (0..xs.len())
            .map(|i| base_w + (i as f32 * 0.1).sin() * 5.0)
            .collect();
        let hs: Vec<f32> = (0..ys.len())
            .map(|i| base_h + (i as f32 * 0.15).sin() * 8.0)
            .collect();

        // Seed the filter with the first measurement.
        let mut kalman = KalmanBBox::new_with_state(
            dt,
            u_cx,
            u_cy,
            u_w,
            u_h,
            std_dev_a,
            std_dev_m_cx,
            std_dev_m_cy,
            std_dev_m_w,
            std_dev_m_h,
            xs[0] as f32,
            ys[0] as f32,
            ws[0],
            hs[0],
        );

        let mut predictions: Vec<(f32, f32, f32, f32)> = Vec::new();
        let mut updated_states: Vec<(f32, f32, f32, f32)> = Vec::new();

        let centres = xs.iter().zip(ys.iter());
        let sizes = ws.iter().zip(hs.iter());
        for ((&m_cx, &m_cy), (&m_w, &m_h)) in centres.zip(sizes) {
            kalman.predict();
            predictions.push(kalman.get_state());

            kalman.update(m_cx as f32, m_cy as f32, m_w, m_h).unwrap();
            updated_states.push(kalman.get_state());
        }

        assert!(predictions.len() == xs.len());
        assert!(updated_states.len() == xs.len());

        let &(final_cx, final_cy, final_w, final_h) = updated_states.last().unwrap();
        let last_m_cx = *xs.last().unwrap() as f32;
        let last_m_cy = *ys.last().unwrap() as f32;
        let last_m_w = *ws.last().unwrap();
        let last_m_h = *hs.last().unwrap();

        assert!((final_cx - last_m_cx).abs() < 5.0, "Final cx too far from measurement");
        assert!((final_cy - last_m_cy).abs() < 5.0, "Final cy too far from measurement");
        assert!((final_w - last_m_w).abs() < 5.0, "Final w too far from measurement");
        assert!((final_h - last_m_h).abs() < 5.0, "Final h too far from measurement");
    }

    /// A measurement near the predicted state must be closer (in squared
    /// Mahalanobis distance) than a distant one, and must pass the gate.
    #[test]
    fn test_mahalanobis_distance() {
        let dt = 0.04;
        let mut kalman = KalmanBBox::new_with_state(
            dt, 1.0, 1.0, 0.0, 0.0,
            2.0, 0.1, 0.1, 0.1, 0.1,
            100.0, 50.0, 40.0, 80.0,
        );

        kalman.predict();

        let d_close = kalman
            .mahalanobis_distance_squared(100.5, 50.5, 40.0, 80.0)
            .unwrap();
        let d_far = kalman
            .mahalanobis_distance_squared(200.0, 150.0, 60.0, 120.0)
            .unwrap();

        assert!(d_close < d_far, "Close measurement should have smaller distance");

        // 9.488 is approximately the 0.95 quantile of the chi-square
        // distribution with 4 degrees of freedom.
        let threshold = 9.488;
        assert!(
            kalman.gating_check(100.5, 50.5, 40.0, 80.0, threshold),
            "Close measurement should pass gate"
        );
    }

    /// Feeding measurements that grow linearly on every axis must yield
    /// positive rate estimates on every axis.
    #[test]
    fn test_velocity_tracking() {
        let dt = 0.04;
        let mut kalman = KalmanBBox::new_with_state(
            dt, 0.0, 0.0, 0.0, 0.0,
            2.0, 0.1, 0.1, 0.1, 0.1,
            100.0, 50.0, 40.0, 80.0,
        );

        for step in 1..20 {
            let t = step as f32;
            let cx = 100.0 + t * 2.0;
            let cy = 50.0 + t * 1.5;
            let w = 40.0 + t * 0.5;
            let h = 80.0 + t * 0.3;

            kalman.predict();
            kalman.update(cx, cy, w, h).unwrap();
        }

        let (vx, vy, vw, vh) = kalman.get_velocity();

        assert!(vx > 0.0, "Velocity X should be positive");
        assert!(vy > 0.0, "Velocity Y should be positive");
        assert!(vw > 0.0, "Velocity W should be positive");
        assert!(vh > 0.0, "Velocity H should be positive");
    }
}