kalman_rust/kalman/
kalman_bbox.rs

use nalgebra;
use std::error::Error;
use std::fmt;
/// Error returned when a `nalgebra` operation required by the filter fails.
///
/// The kind of failure is encoded in the private `typ` code:
/// * `1` - a matrix could not be inverted (e.g. a singular innovation covariance);
/// * any other value - unspecified error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KalmanBBoxError {
    typ: u16,
}

impl fmt::Display for KalmanBBoxError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.typ {
            1 => write!(f, "Can't inverse matrix"),
            _ => write!(f, "Undefined error"),
        }
    }
}

impl Error for KalmanBBoxError {}
19
20/// Implementation of Discrete Kalman filter for bounding box tracking.
21/// State vector contains: center_x, center_y, width, height, velocity_cx, velocity_cy, velocity_w, velocity_h
22#[derive(Debug, Clone)]
23pub struct KalmanBBox {
24    // Single cycle time
25    dt: f32,
26    // Control input (accelerations for cx, cy, w, h)
27    u: nalgebra::SMatrix<f32, 4, 1>,
28    // Standart deviation of acceleration
29    std_dev_a: f32,
30    // Standart deviation of measurement for center X
31    std_dev_m_cx: f32,
32    // Standart deviation of measurement for center Y
33    std_dev_m_cy: f32,
34    // Standart deviation of measurement for width
35    std_dev_m_w: f32,
36    // Standart deviation of measurement for height
37    std_dev_m_h: f32,
38    // Transition matrix
39    A: nalgebra::SMatrix<f32, 8, 8>,
40    // Control matrix
41    B: nalgebra::SMatrix<f32, 8, 4>,
42    // Transformation (observation) matrix
43    H: nalgebra::SMatrix<f32, 4, 8>,
44    // Process noise covariance matrix
45    Q: nalgebra::SMatrix<f32, 8, 8>,
46    // Measurement noise covariance matrix
47    R: nalgebra::SMatrix<f32, 4, 4>,
48    // Error covariance matrix
49    P: nalgebra::SMatrix<f32, 8, 8>,
50    // State vector: cx, cy, w, h, vx, vy, vw, vh
51    x: nalgebra::SVector<f32, 8>,
52}
53
54impl KalmanBBox {
55    /// Creates new `KalmanBBox`
56    ///
57    /// Basic usage:
58    ///
59    /// ```
60    /// use kalman_rust::kalman::KalmanBBox;
61    /// let dt = 0.04; // Single cycle time (1/25 fps)
62    /// let u_cx = 1.0; // Control input for center X
63    /// let u_cy = 1.0; // Control input for center Y
64    /// let u_w = 0.0; // Control input for width
65    /// let u_h = 0.0; // Control input for height
66    /// let std_dev_a = 2.0; // Standart deviation of acceleration
67    /// let std_dev_m_cx = 0.1; // Standart deviation of measurement for center X
68    /// let std_dev_m_cy = 0.1; // Standart deviation of measurement for center Y
69    /// let std_dev_m_w = 0.1; // Standart deviation of measurement for width
70    /// let std_dev_m_h = 0.1; // Standart deviation of measurement for height
71    /// let mut kalman = KalmanBBox::new(dt, u_cx, u_cy, u_w, u_h, std_dev_a, std_dev_m_cx, std_dev_m_cy, std_dev_m_w, std_dev_m_h);
72    /// ```
73    pub fn new(
74        dt: f32,
75        u_cx: f32,
76        u_cy: f32,
77        u_w: f32,
78        u_h: f32,
79        std_dev_a: f32,
80        std_dev_m_cx: f32,
81        std_dev_m_cy: f32,
82        std_dev_m_w: f32,
83        std_dev_m_h: f32,
84    ) -> Self {
85        let dt2 = dt * dt;
86        let dt3 = dt2 * dt;
87        let dt4 = dt3 * dt;
88        let std_dev_a2 = std_dev_a * std_dev_a;
89
90        KalmanBBox {
91            dt,
92            u: nalgebra::SMatrix::<f32, 4, 1>::new(u_cx, u_cy, u_w, u_h),
93            std_dev_a,
94            std_dev_m_cx,
95            std_dev_m_cy,
96            std_dev_m_w,
97            std_dev_m_h,
98            // Ref.: Eq.(53)
99            #[rustfmt::skip]
100            A: nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
101                1.0, 0.0, 0.0, 0.0, dt, 0.0, 0.0, 0.0,
102                0.0, 1.0, 0.0, 0.0, 0.0, dt, 0.0, 0.0,
103                0.0, 0.0, 1.0, 0.0, 0.0, 0.0, dt, 0.0,
104                0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, dt,
105                0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
106                0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
107                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
108                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
109            ]),
110            // Ref.: Eq.(54)
111            #[rustfmt::skip]
112            B: nalgebra::SMatrix::<f32, 8, 4>::from_row_slice(&[
113                0.5 * dt2, 0.0, 0.0, 0.0,
114                0.0, 0.5 * dt2, 0.0, 0.0,
115                0.0, 0.0, 0.5 * dt2, 0.0,
116                0.0, 0.0, 0.0, 0.5 * dt2,
117                dt, 0.0, 0.0, 0.0,
118                0.0, dt, 0.0, 0.0,
119                0.0, 0.0, dt, 0.0,
120                0.0, 0.0, 0.0, dt,
121            ]),
122            // Ref.: Eq.(56)
123            #[rustfmt::skip]
124            H: nalgebra::SMatrix::<f32, 4, 8>::from_row_slice(&[
125                1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
126                0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
127                0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
128                0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
129            ]),
130            // Ref.: Eq.(62)
131            #[rustfmt::skip]
132            Q: nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
133                0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0,
134                0.0, 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0,
135                0.0, 0.0, 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0,
136                0.0, 0.0, 0.0, 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2,
137                0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2, 0.0, 0.0, 0.0,
138                0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2, 0.0, 0.0,
139                0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2, 0.0,
140                0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2,
141            ]),
142            // Ref.: Eq.(63)
143            R: nalgebra::SMatrix::<f32, 4, 4>::new(
144                std_dev_m_cx * std_dev_m_cx, 0.0, 0.0, 0.0,
145                0.0, std_dev_m_cy * std_dev_m_cy, 0.0, 0.0,
146                0.0, 0.0, std_dev_m_w * std_dev_m_w, 0.0,
147                0.0, 0.0, 0.0, std_dev_m_h * std_dev_m_h,
148            ),
149            #[rustfmt::skip]
150            P: nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
151                1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
152                0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
153                0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
154                0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
155                0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
156                0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
157                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
158                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
159            ]),
160            #[rustfmt::skip]
161            x: nalgebra::SVector::<f32, 8>::from_row_slice(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
162        }
163    }
164
165    /// Creates new `KalmanBBox` with initial state
166    ///
167    /// Why is it needed to set the initial state to the actual first observed bounding box (sometimes)?
168    /// When the first state vector is initialized with zeros, it assumes that the object is at the origin
169    /// and the filter needs to estimate the position of the object from scratch, which can result in some initial inaccuracies.
170    /// On the other hand, initializing the first state vector with the actual observed bounding box can provide
171    /// a more accurate estimate from the beginning, which can improve the overall tracking performance of the filter
172    ///
173    /// Basic usage:
174    ///
175    /// ```
176    /// use kalman_rust::kalman::KalmanBBox;
177    /// let dt = 0.04; // Single cycle time (1/25 fps)
178    /// let u_cx = 1.0; // Control input for center X
179    /// let u_cy = 1.0; // Control input for center Y
180    /// let u_w = 0.0; // Control input for width
181    /// let u_h = 0.0; // Control input for height
182    /// let std_dev_a = 2.0; // Standart deviation of acceleration
183    /// let std_dev_m_cx = 0.1; // Standart deviation of measurement for center X
184    /// let std_dev_m_cy = 0.1; // Standart deviation of measurement for center Y
185    /// let std_dev_m_w = 0.1; // Standart deviation of measurement for width
186    /// let std_dev_m_h = 0.1; // Standart deviation of measurement for height
187    /// let i_cx = 100.0; // Initial center X
188    /// let i_cy = 50.0; // Initial center Y
189    /// let i_w = 40.0; // Initial width
190    /// let i_h = 80.0; // Initial height
191    /// let mut kalman = KalmanBBox::new_with_state(dt, u_cx, u_cy, u_w, u_h, std_dev_a, std_dev_m_cx, std_dev_m_cy, std_dev_m_w, std_dev_m_h, i_cx, i_cy, i_w, i_h);
192    /// ```
193    pub fn new_with_state(
194        dt: f32,
195        u_cx: f32,
196        u_cy: f32,
197        u_w: f32,
198        u_h: f32,
199        std_dev_a: f32,
200        std_dev_m_cx: f32,
201        std_dev_m_cy: f32,
202        std_dev_m_w: f32,
203        std_dev_m_h: f32,
204        cx: f32,
205        cy: f32,
206        w: f32,
207        h: f32,
208    ) -> Self {
209        let dt2 = dt * dt;
210        let dt3 = dt2 * dt;
211        let dt4 = dt3 * dt;
212        let std_dev_a2 = std_dev_a * std_dev_a;
213
214        KalmanBBox {
215            dt,
216            u: nalgebra::SMatrix::<f32, 4, 1>::new(u_cx, u_cy, u_w, u_h),
217            std_dev_a,
218            std_dev_m_cx,
219            std_dev_m_cy,
220            std_dev_m_w,
221            std_dev_m_h,
222            // Ref.: Eq.(53)
223            #[rustfmt::skip]
224            A: nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
225                1.0, 0.0, 0.0, 0.0, dt, 0.0, 0.0, 0.0,
226                0.0, 1.0, 0.0, 0.0, 0.0, dt, 0.0, 0.0,
227                0.0, 0.0, 1.0, 0.0, 0.0, 0.0, dt, 0.0,
228                0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, dt,
229                0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
230                0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
231                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
232                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
233            ]),
234            // Ref.: Eq.(54)
235            #[rustfmt::skip]
236            B: nalgebra::SMatrix::<f32, 8, 4>::from_row_slice(&[
237                0.5 * dt2, 0.0, 0.0, 0.0,
238                0.0, 0.5 * dt2, 0.0, 0.0,
239                0.0, 0.0, 0.5 * dt2, 0.0,
240                0.0, 0.0, 0.0, 0.5 * dt2,
241                dt, 0.0, 0.0, 0.0,
242                0.0, dt, 0.0, 0.0,
243                0.0, 0.0, dt, 0.0,
244                0.0, 0.0, 0.0, dt,
245            ]),
246            // Ref.: Eq.(56)
247            #[rustfmt::skip]
248            H: nalgebra::SMatrix::<f32, 4, 8>::from_row_slice(&[
249                1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
250                0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
251                0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
252                0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
253            ]),
254            // Ref.: Eq.(62)
255            #[rustfmt::skip]
256            Q: nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
257                0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0,
258                0.0, 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0,
259                0.0, 0.0, 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0,
260                0.0, 0.0, 0.0, 0.25 * dt4 * std_dev_a2, 0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2,
261                0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2, 0.0, 0.0, 0.0,
262                0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2, 0.0, 0.0,
263                0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2, 0.0,
264                0.0, 0.0, 0.0, 0.5 * dt3 * std_dev_a2, 0.0, 0.0, 0.0, dt2 * std_dev_a2,
265            ]),
266            // Ref.: Eq.(63)
267            R: nalgebra::SMatrix::<f32, 4, 4>::new(
268                std_dev_m_cx * std_dev_m_cx, 0.0, 0.0, 0.0,
269                0.0, std_dev_m_cy * std_dev_m_cy, 0.0, 0.0,
270                0.0, 0.0, std_dev_m_w * std_dev_m_w, 0.0,
271                0.0, 0.0, 0.0, std_dev_m_h * std_dev_m_h,
272            ),
273            #[rustfmt::skip]
274            P: nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
275                1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
276                0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
277                0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
278                0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
279                0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
280                0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
281                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
282                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
283            ]),
284            #[rustfmt::skip]
285            x: nalgebra::SVector::<f32, 8>::from_row_slice(&[cx, cy, w, h, 0.0, 0.0, 0.0, 0.0]),
286        }
287    }
288
289    /// Projects the state and the error covariance ahead
290    /// Mutates the state vector and the error covariance matrix
291    ///
292    /// Basic usage:
293    ///
294    /// ```
295    /// use kalman_rust::kalman::KalmanBBox;
296    /// let dt = 0.04;
297    /// let mut kalman = KalmanBBox::new(dt, 1.0, 1.0, 0.0, 0.0, 2.0, 0.1, 0.1, 0.1, 0.1);
298    /// let measurements = vec![(100.0, 50.0, 40.0, 80.0), (102.0, 52.0, 41.0, 81.0)];
299    /// for _ in measurements.iter() {
300    ///     // get measurement
301    ///     kalman.predict();
302    ///     // then do update
303    /// }
304    /// ```
305    pub fn predict(&mut self) {
306        // Ref.: Eq.(5)
307        self.x = (self.A * self.x) + (self.B * self.u);
308        // Ref.: Eq.(6)
309        self.P = self.A * self.P * self.A.transpose() + self.Q;
310    }
311
312    /// Computes the Kalman gain and then updates the state vector and the error covariance matrix
313    /// Mutates the state vector and the error covariance matrix.
314    ///
315    /// Basic usage:
316    ///
317    /// ```
318    /// use kalman_rust::kalman::KalmanBBox;
319    /// let dt = 0.04;
320    /// let mut kalman = KalmanBBox::new(dt, 1.0, 1.0, 0.0, 0.0, 2.0, 0.1, 0.1, 0.1, 0.1);
321    /// let measurements = vec![(100.0, 50.0, 40.0, 80.0), (102.0, 52.0, 41.0, 81.0)];
322    /// for (cx, cy, w, h) in measurements.iter() {
323    ///     kalman.predict();
324    ///     kalman.update(*cx, *cy, *w, *h).unwrap();
325    /// }
326    /// ```
327    pub fn update(&mut self, _z_cx: f32, _z_cy: f32, _z_w: f32, _z_h: f32) -> Result<(), KalmanBBoxError> {
328        // Ref.: Eq.(7)
329        let S = self.H * self.P * self.H.transpose() + self.R;
330        let gain = match S.try_inverse() {
331            Some(inv) => self.P * self.H.transpose() * inv,
332            None => return Err(KalmanBBoxError { typ: 1 }),
333        };
334        // Ref.: Eq.(8)
335        let z = nalgebra::SMatrix::<f32, 4, 1>::new(_z_cx, _z_cy, _z_w, _z_h);
336        let r = z - self.H * self.x;
337        // Ref.: Eq.(9)
338        self.x = self.x + gain * r;
339        // Ref.: Eq.(10)
340        // Identity matrix for 8x8
341        #[rustfmt::skip]
342        let I: nalgebra::SMatrix<f32, 8, 8> = nalgebra::SMatrix::<f32, 8, 8>::from_row_slice(&[
343            1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
344            0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
345            0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
346            0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
347            0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
348            0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
349            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
350            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
351        ]);
352        self.P = (I - gain * self.H) * self.P;
353        Ok(())
354    }
355
356    /// Returns the current state (only cx, cy, w, h - not velocities)
357    pub fn get_state(&self) -> (f32, f32, f32, f32) {
358        (self.x[0], self.x[1], self.x[2], self.x[3])
359    }
360
361    /// Returns the current velocity (only vx, vy, vw, vh)
362    pub fn get_velocity(&self) -> (f32, f32, f32, f32) {
363        (self.x[4], self.x[5], self.x[6], self.x[7])
364    }
365
366    /// Returns prediction without mutating the state vector and the error covariance matrix
367    pub fn get_predicted_state(&self) -> (f32, f32, f32, f32) {
368        let x_pred = (self.A * self.x) + (self.B * self.u);
369        (x_pred[0], x_pred[1], x_pred[2], x_pred[3])
370    }
371
372    /// Returns position uncertainty from P matrix (for center coordinates)
373    pub fn get_position_uncertainty(&self) -> f32 {
374        (self.P[(0, 0)].powi(2) + self.P[(1, 1)].powi(2)).sqrt()
375    }
376
377    /// Returns the current state (full 8-element vector)
378    pub fn get_vector_state(&self) -> nalgebra::SVector<f32, 8> {
379        self.x
380    }
381
382    /// Computes the squared Mahalanobis distance between measurement and predicted state
383    /// Ref.: Eq.(66)
384    ///
385    /// This is useful for data association in multi-object tracking.
386    /// A detection passes the gate if mahalanobis_distance_squared < threshold
387    /// Common threshold for 4 DOF at 95% confidence: 9.488
388    pub fn mahalanobis_distance_squared(&self, z_cx: f32, z_cy: f32, z_w: f32, z_h: f32) -> Result<f32, KalmanBBoxError> {
389        // Ref.: Eq.(64) - Innovation (residual)
390        let z = nalgebra::SMatrix::<f32, 4, 1>::new(z_cx, z_cy, z_w, z_h);
391        let y = z - self.H * self.x;
392
393        // Ref.: Eq.(65) - Innovation covariance
394        let S = self.H * self.P * self.H.transpose() + self.R;
395
396        // Ref.: Eq.(66) - Squared Mahalanobis distance
397        let S_inv = match S.try_inverse() {
398            Some(inv) => inv,
399            None => return Err(KalmanBBoxError { typ: 1 }),
400        };
401
402        let d_squared = (y.transpose() * S_inv * y)[(0, 0)];
403        Ok(d_squared)
404    }
405
406    /// Computes the Mahalanobis distance (square root of squared distance)
407    /// Ref.: Eq.(67)
408    pub fn mahalanobis_distance(&self, z_cx: f32, z_cy: f32, z_w: f32, z_h: f32) -> Result<f32, KalmanBBoxError> {
409        let d_squared = self.mahalanobis_distance_squared(z_cx, z_cy, z_w, z_h)?;
410        Ok(d_squared.sqrt())
411    }
412
413    /// Returns the innovation covariance matrix S
414    /// Ref.: Eq.(65)
415    ///
416    /// S = H * P * H^T + R
417    pub fn get_innovation_covariance(&self) -> nalgebra::SMatrix<f32, 4, 4> {
418        self.H * self.P * self.H.transpose() + self.R
419    }
420
421    /// Checks if a measurement passes the gating threshold
422    /// Ref.: Eq.(70)
423    ///
424    /// Common thresholds for alpha = 0.95:
425    /// - 4 DOF (bounding box): 9.488
426    /// - 2 DOF (position only): 5.991
427    pub fn gating_check(&self, z_cx: f32, z_cy: f32, z_w: f32, z_h: f32, threshold: f32) -> bool {
428        match self.mahalanobis_distance_squared(z_cx, z_cy, z_w, z_h) {
429            Ok(d_sq) => d_sq < threshold,
430            Err(_) => false,
431        }
432    }
433}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bbox_kalman() {
        let dt = 0.04; // 1/25 = 25 fps - just an example
        let u_cx = 1.0;
        let u_cy = 1.0;
        let u_w = 0.0; // No control input for size
        let u_h = 0.0;
        let std_dev_a = 2.0;
        let std_dev_m_cx = 0.1;
        let std_dev_m_cy = 0.1;
        let std_dev_m_w = 0.1;
        let std_dev_m_h = 0.1;

        // Sample measurements for center coordinates (same as test_2d_kalman).
        // Note: in this example the Y axis points downwards.
        let xs = [
            311, 312, 313, 311, 311, 312, 312, 313, 312, 312, 312, 312, 312, 312, 312, 312, 312,
            312, 311, 311, 311, 311, 311, 310, 311, 311, 311, 310, 310, 308, 307, 308, 308, 308,
            307, 307, 307, 308, 307, 307, 307, 307, 307, 308, 307, 309, 306, 307, 306, 307, 308,
            306, 306, 306, 305, 307, 307, 307, 306, 306, 306, 307, 307, 308, 307, 307, 308, 307,
            306, 308, 309, 309, 309, 309, 308, 309, 309, 309, 308, 311, 311, 307, 311, 307, 313,
            311, 307, 311, 311, 306, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312,
            312, 312, 312, 312, 312, 312, 312, 312, 312, 312,
        ];
        let ys = [
            5, 6, 8, 10, 11, 12, 12, 13, 16, 16, 18, 18, 19, 19, 20, 20, 22, 22, 23, 23, 24, 24,
            28, 30, 32, 35, 39, 42, 44, 46, 56, 58, 70, 60, 52, 64, 51, 70, 70, 70, 66, 83, 80, 85,
            80, 98, 79, 98, 61, 94, 101, 94, 104, 94, 107, 112, 108, 108, 109, 109, 121, 108, 108,
            120, 122, 122, 128, 130, 122, 140, 122, 122, 140, 122, 134, 141, 136, 136, 154, 155,
            155, 150, 161, 162, 169, 171, 181, 175, 175, 163, 178, 178, 178, 178, 178, 178, 178,
            178, 178, 178, 178, 178, 178, 178, 178, 178, 178, 178, 178, 178, 178, 178,
        ];
        assert_eq!(xs.len(), ys.len(), "center series must have the same length");

        // Synthesize width/height series that oscillate slightly around a base size,
        // simulating an object whose apparent size changes over time.
        let base_w = 40.0_f32;
        let base_h = 80.0_f32;
        let ws: Vec<f32> = (0..xs.len())
            .map(|i| base_w + (i as f32 * 0.1).sin() * 5.0)
            .collect();
        let hs: Vec<f32> = (0..ys.len())
            .map(|i| base_h + (i as f32 * 0.15).sin() * 8.0)
            .collect();

        // Seed the filter with the first measurement.
        let mut kalman = KalmanBBox::new_with_state(
            dt, u_cx, u_cy, u_w, u_h,
            std_dev_a,
            std_dev_m_cx, std_dev_m_cy, std_dev_m_w, std_dev_m_h,
            xs[0] as f32, ys[0] as f32, ws[0], hs[0],
        );

        let mut predictions = Vec::with_capacity(xs.len());
        let mut updated_states = Vec::with_capacity(xs.len());
        let measurements = xs.iter().zip(&ys).zip(&ws).zip(&hs);
        for (((&m_cx, &m_cy), &m_w), &m_h) in measurements {
            // Predict stage
            kalman.predict();
            predictions.push(kalman.get_state());

            // Update stage
            kalman
                .update(m_cx as f32, m_cy as f32, m_w, m_h)
                .expect("innovation covariance should be invertible");
            updated_states.push(kalman.get_state());
        }

        // Every measurement produced exactly one prediction and one update.
        assert_eq!(predictions.len(), xs.len());
        assert_eq!(updated_states.len(), xs.len());

        // The final estimate should be close to the final measurement.
        let &(final_cx, final_cy, final_w, final_h) = updated_states.last().unwrap();
        let last_m_cx = *xs.last().unwrap() as f32;
        let last_m_cy = *ys.last().unwrap() as f32;
        let last_m_w = *ws.last().unwrap();
        let last_m_h = *hs.last().unwrap();

        assert!((final_cx - last_m_cx).abs() < 5.0, "Final cx too far from measurement");
        assert!((final_cy - last_m_cy).abs() < 5.0, "Final cy too far from measurement");
        assert!((final_w - last_m_w).abs() < 5.0, "Final w too far from measurement");
        assert!((final_h - last_m_h).abs() < 5.0, "Final h too far from measurement");
    }

    #[test]
    fn test_mahalanobis_distance() {
        let dt = 0.04;
        let mut kalman = KalmanBBox::new_with_state(
            dt, 1.0, 1.0, 0.0, 0.0,
            2.0, 0.1, 0.1, 0.1, 0.1,
            100.0, 50.0, 40.0, 80.0,
        );

        // After prediction, covariance increases
        kalman.predict();

        // Measurement close to predicted state should have small Mahalanobis distance
        let d_close = kalman
            .mahalanobis_distance_squared(100.5, 50.5, 40.0, 80.0)
            .unwrap();

        // Measurement far from predicted state should have large Mahalanobis distance
        let d_far = kalman
            .mahalanobis_distance_squared(200.0, 150.0, 60.0, 120.0)
            .unwrap();

        assert!(d_close < d_far, "Close measurement should have smaller distance");

        // Check gating in both directions
        let threshold = 9.488; // 95% confidence for 4 DOF
        assert!(
            kalman.gating_check(100.5, 50.5, 40.0, 80.0, threshold),
            "Close measurement should pass gate"
        );
        assert!(
            !kalman.gating_check(200.0, 150.0, 60.0, 120.0, threshold),
            "Far measurement should be rejected by gate"
        );
    }

    #[test]
    fn test_velocity_tracking() {
        let dt = 0.04;
        let mut kalman = KalmanBBox::new_with_state(
            dt, 0.0, 0.0, 0.0, 0.0, // No control input
            2.0, 0.1, 0.1, 0.1, 0.1,
            100.0, 50.0, 40.0, 80.0,
        );

        // Simulate object moving right and down with growing size
        for i in 1..20 {
            let step = i as f32;
            let cx = 100.0 + step * 2.0; // Moving right
            let cy = 50.0 + step * 1.5; // Moving down
            let w = 40.0 + step * 0.5; // Growing width
            let h = 80.0 + step * 0.3; // Growing height

            kalman.predict();
            kalman
                .update(cx, cy, w, h)
                .expect("innovation covariance should be invertible");
        }

        // After tracking, velocity should reflect the motion
        let (vx, vy, vw, vh) = kalman.get_velocity();

        // Velocity should be positive (moving right/down, growing)
        assert!(vx > 0.0, "Velocity X should be positive");
        assert!(vy > 0.0, "Velocity Y should be positive");
        assert!(vw > 0.0, "Velocity W should be positive");
        assert!(vh > 0.0, "Velocity H should be positive");
    }
}