Skip to main content

edgefirst_tracker/
kalman.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use nalgebra::{
5    allocator::Allocator, convert, dimension::U4, DVector, DefaultAllocator, Dyn, OMatrix,
6    RealField, SVector, U1, U8,
7};
8
9#[derive(Debug, Clone)]
10pub struct ConstantVelocityXYAHModel2<R>
11where
12    R: RealField,
13    DefaultAllocator: Allocator<U8, U8>,
14    DefaultAllocator: Allocator<U8>,
15{
16    pub mean: SVector<R, 8>,
17    pub std_weight_position: R,
18    pub std_weight_velocity: R,
19    pub update_factor: R,
20    motion_matrix: OMatrix<R, U8, U8>,
21    update_matrix: OMatrix<R, U4, U8>,
22    pub covariance: OMatrix<R, U8, U8>,
23}
24
/// Distance used by `ConstantVelocityXYAHModel2::gating_distance` to compare
/// measurements against a track's projected state.
// NOTE(review): `dead_code` is presumably allowed because not every variant is
// constructed inside this crate — confirm whether the allow is still needed.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GatingDistanceMetric {
    /// Sum of squared differences between a measurement and the projected mean.
    Gaussian,
    /// Distance weighted by the projected measurement covariance, computed via
    /// its Cholesky factorisation.
    Mahalanobis,
}
30
31impl<R> ConstantVelocityXYAHModel2<R>
32where
33    R: RealField + Copy,
34{
35    pub fn new(measurement: &[R; 4], update_factor: R) -> Self {
36        let ndim = 4;
37        let dt: R = convert(1.0);
38
39        let mut motion_matrix = OMatrix::<R, U8, U8>::identity();
40        for i in 0..ndim {
41            motion_matrix[(i, ndim + i)] = dt * convert(3.0);
42        }
43        let mut update_matrix = OMatrix::<R, U4, U8>::identity();
44        for i in 0..ndim {
45            update_matrix[(i, ndim + i)] = dt * convert(1.0);
46        }
47        let zero: R = convert(0.0);
48        let two: R = convert(2.0);
49        let ten: R = convert(10.0);
50        let height = measurement[3];
51
52        let mean = SVector::<R, 8>::from_row_slice(&[
53            measurement[0],
54            measurement[1],
55            measurement[2],
56            measurement[3],
57            zero,
58            zero,
59            zero,
60            zero,
61        ]);
62        let std_weight_position = convert(1.0 / 20.0);
63        let std_weight_velocity = convert(1.0 / 160.0);
64        let diag = [
65            two * std_weight_position * height,
66            two * std_weight_position * height,
67            convert(0.01),
68            two * std_weight_position * height,
69            ten * std_weight_velocity * height,
70            ten * std_weight_velocity * height,
71            convert(0.00001),
72            ten * std_weight_velocity * height,
73        ];
74        let diag = SVector::<R, 8>::from_row_slice(&diag);
75
76        let covariance = OMatrix::<R, U8, U8>::from_diagonal(&diag.component_mul(&diag));
77        Self {
78            motion_matrix,
79            update_matrix,
80            mean,
81            covariance,
82            std_weight_position,
83            std_weight_velocity,
84            update_factor,
85        }
86    }
87
88    pub fn predict(&mut self) {
89        let height = self.mean[3];
90        let diag = [
91            self.std_weight_position * height,
92            self.std_weight_position * height,
93            convert(0.01),
94            self.std_weight_position * height,
95            self.std_weight_velocity * height,
96            self.std_weight_velocity * height,
97            convert(0.00001),
98            self.std_weight_velocity * height,
99        ];
100        let diag = SVector::<R, 8>::from_row_slice(&diag);
101        let motion_cov = OMatrix::<R, U8, U8>::from_diagonal(&diag.component_mul(&diag));
102
103        let mean = (self.mean.transpose() * self.motion_matrix.transpose()).transpose();
104        let covariance =
105            self.motion_matrix * self.covariance * self.motion_matrix.transpose() + motion_cov;
106        self.mean = mean;
107        self.covariance = covariance;
108    }
109
110    pub fn project(&self) -> (OMatrix<R, U4, U1>, OMatrix<R, U4, U4>) {
111        let height = self.mean[3];
112        let diag = [
113            self.std_weight_position * height,
114            self.std_weight_position * height,
115            convert(0.01),
116            self.std_weight_position * height,
117        ];
118        let diag = SVector::<R, 4>::from_row_slice(&diag);
119        let innovation_cov = OMatrix::<R, U4, U4>::from_diagonal(&diag.component_mul(&diag));
120        let mean = self.update_matrix * self.mean;
121        let covariance =
122            self.update_matrix * self.covariance * self.update_matrix.transpose() + innovation_cov;
123        (mean, covariance)
124    }
125
126    pub fn update(&mut self, measurement: &[R; 4]) {
127        let measurement = SVector::<R, 4>::from_row_slice(&[
128            measurement[0],
129            measurement[1],
130            measurement[2],
131            measurement[3],
132        ]);
133
134        let (projected_mean, projected_cov) = self.project();
135        let cho_factor = match projected_cov.cholesky() {
136            None => return,
137            Some(v) => v,
138        };
139        let kalman_gain = cho_factor
140            .solve(&(self.covariance * self.update_matrix.transpose()).transpose())
141            .transpose();
142
143        let innovation = (measurement - projected_mean).scale(self.update_factor);
144        // println!("kalman_gain={}", kalman_gain);
145        // println!("innovation={}", innovation);
146        let diff = innovation.transpose() * kalman_gain.transpose();
147        self.mean += diff.transpose();
148        self.covariance -= kalman_gain * projected_cov * kalman_gain.transpose();
149        // let new_mean = self.mean + diff.transpose();
150        // let new_cov = self.covariance - kalman_gain * projected_cov *
151        // kalman_gain.transpose();
152    }
153
154    #[allow(dead_code)]
155    pub fn gating_distance(
156        &self,
157        measurements: &OMatrix<R, Dyn, U4>,
158        only_position: bool,
159        metric: GatingDistanceMetric,
160    ) -> DVector<R> {
161        let (m, cov) = self.project();
162        let ndims = if only_position { 2 } else { 4 };
163        let mean = m.transpose();
164        let mean = mean.columns_range(0..ndims);
165        let covariance = cov.view_range(0..ndims, 0..ndims);
166        let measurements = measurements.columns_range(0..ndims);
167        // let _ = only_position;
168        // let mean = m.transpose();
169        // let covariance = cov;
170        // let measurements = measurements;
171
172        let mut mean_broadcast =
173            OMatrix::<R, Dyn, U4>::from_element(measurements.shape().0, convert(0.0));
174        for mut col in mean_broadcast.row_iter_mut() {
175            col.copy_from(&mean);
176        }
177        let d = measurements - mean_broadcast;
178        match metric {
179            GatingDistanceMetric::Gaussian => d.component_mul(&d).column_sum(),
180            GatingDistanceMetric::Mahalanobis => {
181                let cho_factor = match covariance.cholesky() {
182                    None => return DVector::<R>::zeros(measurements.shape().0),
183                    Some(v) => v,
184                };
185                let z = cho_factor.solve(&d.transpose());
186                z.component_mul(&z).row_sum_tr()
187            }
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use nalgebra::{Dyn, OMatrix, U4};

    use super::{ConstantVelocityXYAHModel2, GatingDistanceMetric};

    /// Builds a one-row measurement matrix holding `values` as `[x, y, a, h]`.
    fn single_measurement(values: [f32; 4]) -> OMatrix<f32, Dyn, U4> {
        let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(1, 0.0);
        measurements.copy_from_slice(&values);
        measurements
    }

    #[test]
    fn filter() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        // Only x varies between measurements; y, a and h stay at their initial values.
        let xs = [0.4, 0.3, 0.2, 0.2, 0.3, 0.4];
        for (step, &x) in xs.iter().enumerate() {
            t.predict();
            println!("{}. t.mean={}", step + 1, t.mean);
            t.update(&[x, 0.5, 1.0, 0.5]);
        }

        for i in 0..8 {
            let v: f32 = t.mean[i];
            assert!(v.is_finite(), "mean[{i}] should be finite, got {v}");
        }
        let x: f32 = t.mean[0];
        assert!(
            x > 0.0 && x < 1.0,
            "x should stay near the measurements, got {x}"
        );
        // Components whose measurements never change receive zero innovation.
        for (i, expected) in [(1, 0.5f32), (2, 1.0), (3, 0.5)] {
            let v: f32 = t.mean[i];
            assert!(
                (v - expected).abs() < 1e-5,
                "mean[{i}] should stay at {expected}, got {v}"
            );
        }
    }

    #[test]
    fn gating() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        t.predict();
        for x in [0.49, 0.48, 0.47, 0.46, 0.45, 0.44, 0.43, 0.42] {
            t.update(&[x, 0.5, 1.0, 0.5]);
            t.predict();
        }

        let measurements = single_measurement([0.3, 0.5, 1.0, 0.5]);

        let dist_maha = t.gating_distance(&measurements, false, GatingDistanceMetric::Mahalanobis);
        println!("Dist(false, maha): {dist_maha}");
        let dist_gauss = t.gating_distance(&measurements, false, GatingDistanceMetric::Gaussian);
        println!("Dist(false, gaussian): {dist_gauss}");

        assert_eq!(dist_maha.len(), 1, "one distance per measurement");
        assert_eq!(dist_gauss.len(), 1, "one distance per measurement");
        assert!(
            dist_maha[0].is_finite() && dist_maha[0] > 0.0,
            "Mahalanobis distance should be finite and positive, got {}",
            dist_maha[0]
        );

        // The Gaussian metric is the plain sum of squared residuals.
        let (projected_mean, _) = t.project();
        let expected: f32 = [0.3f32, 0.5, 1.0, 0.5]
            .iter()
            .zip(projected_mean.iter())
            .map(|(z, m)| (z - m) * (z - m))
            .sum();
        assert!(
            (dist_gauss[0] - expected).abs() < 1e-6,
            "Gaussian distance should be {expected}, got {}",
            dist_gauss[0]
        );
    }

    #[test]
    fn test_predict_constant_velocity() {
        // Initialise the filter and run one predict/update cycle with an
        // unchanged measurement, so the estimated velocity stays near zero.
        // Repeated predictions must then remain finite and bounded.
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 0.1, 2.0], 0.25);
        t.predict();
        t.update(&[0.5, 0.5, 0.1, 2.0]);

        // Record position after first predict-update cycle
        let x_before: f32 = t.mean[0];
        let y_before: f32 = t.mean[1];

        // Run several predict-only cycles
        for _ in 0..5 {
            t.predict();
        }

        let x_after: f32 = t.mean[0];
        let y_after: f32 = t.mean[1];
        let h_after: f32 = t.mean[3];

        // The position should remain numerically reasonable (no NaN/Inf)
        assert!(x_after.is_finite(), "x should be finite after predictions");
        assert!(y_after.is_finite(), "y should be finite after predictions");
        assert!(
            h_after.is_finite(),
            "height should be finite after predictions"
        );

        // With near-zero velocity the predicted position should not explode
        assert!(
            (x_after - x_before).abs() < 5.0,
            "x drift should be bounded, got delta={}",
            (x_after - x_before).abs()
        );
        assert!(
            (y_after - y_before).abs() < 5.0,
            "y drift should be bounded, got delta={}",
            (y_after - y_before).abs()
        );
    }

    #[test]
    fn test_numerical_stability_1000_cycles() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);

        // Run 1000 predict-only cycles without any update
        for _ in 0..1000 {
            t.predict();
        }

        // Verify no NaN or Inf in the mean vector
        for i in 0..8 {
            let val: f32 = t.mean[i];
            assert!(
                val.is_finite(),
                "mean[{i}] should be finite after 1000 predictions, got {val}",
            );
        }

        // Verify no NaN or Inf in the covariance matrix
        for r in 0..8 {
            for c in 0..8 {
                let val: f32 = t.covariance[(r, c)];
                assert!(
                    val.is_finite(),
                    "covariance[({r},{c})] should be finite after 1000 predictions, got {val}",
                );
            }
        }
    }

    #[test]
    fn test_gating_distance_edge_cases() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        // Run a few predict-update cycles to stabilize
        for _ in 0..3 {
            t.predict();
            t.update(&[0.5, 0.5, 1.0, 0.5]);
        }
        t.predict();

        // Measurement exactly at the predicted state -- distance should be near 0
        let (projected_mean, _) = t.project();
        let mut meas_close = OMatrix::<f32, Dyn, U4>::from_element(1, 0.0);
        meas_close
            .row_mut(0)
            .copy_from_slice(projected_mean.as_slice());

        let dist_close = t.gating_distance(&meas_close, false, GatingDistanceMetric::Mahalanobis);
        assert!(
            dist_close[0].is_finite(),
            "Close-measurement distance should be finite"
        );
        assert!(
            dist_close[0] < 1.0,
            "Distance for exact-match measurement should be near 0, got {}",
            dist_close[0]
        );

        // Measurement far away -- distance should be large
        let meas_far = single_measurement([10.0, 10.0, 5.0, 10.0]);

        let dist_far = t.gating_distance(&meas_far, false, GatingDistanceMetric::Mahalanobis);
        assert!(
            dist_far[0].is_finite(),
            "Far-measurement distance should be finite"
        );
        assert!(
            dist_far[0] > dist_close[0],
            "Far measurement should have larger distance than close one: {} vs {}",
            dist_far[0],
            dist_close[0]
        );
    }

    #[test]
    fn test_update_moves_mean_toward_measurement() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        t.predict();

        let x_before: f32 = t.mean[0];
        // Update with a measurement shifted to the right
        t.update(&[0.6, 0.5, 1.0, 0.5]);
        let x_after: f32 = t.mean[0];

        assert!(
            x_after > x_before,
            "Mean x should move toward the measurement (0.6), was {x_before}, now {x_after}"
        );
        assert!(
            x_after <= 0.6,
            "Mean x should not overshoot the measurement, got {x_after}"
        );
    }

    #[test]
    fn test_covariance_positive_diagonal() {
        let t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);

        // All diagonal elements of the covariance should be positive
        for i in 0..8 {
            let val: f32 = t.covariance[(i, i)];
            assert!(
                val > 0.0,
                "Covariance diagonal[{i}] should be positive, got {val}"
            );
        }
    }

    #[test]
    fn test_predict_increases_uncertainty() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);

        let cov_before: f32 = t.covariance[(0, 0)];
        t.predict();
        let cov_after: f32 = t.covariance[(0, 0)];

        assert!(
            cov_after > cov_before,
            "Predict should increase position uncertainty: {cov_before} -> {cov_after}"
        );
    }

    #[test]
    fn test_update_decreases_uncertainty() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        t.predict();

        let cov_before: f32 = t.covariance[(0, 0)];
        t.update(&[0.5, 0.5, 1.0, 0.5]);
        let cov_after: f32 = t.covariance[(0, 0)];

        assert!(
            cov_after < cov_before,
            "Update should decrease position uncertainty: {cov_before} -> {cov_after}"
        );
    }

    #[test]
    fn test_gating_distance_gaussian_vs_mahalanobis() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        for _ in 0..3 {
            t.predict();
            t.update(&[0.5, 0.5, 1.0, 0.5]);
        }
        t.predict();

        let measurements = single_measurement([0.6, 0.5, 1.0, 0.5]);

        let dist_gauss = t.gating_distance(&measurements, false, GatingDistanceMetric::Gaussian);
        let dist_maha = t.gating_distance(&measurements, false, GatingDistanceMetric::Mahalanobis);

        assert!(dist_gauss[0].is_finite());
        assert!(dist_maha[0].is_finite());

        // Both should be strictly positive for a non-zero offset
        assert!(
            dist_gauss[0] > 0.0,
            "Gaussian distance should be > 0 for offset measurement"
        );
        assert!(
            dist_maha[0] > 0.0,
            "Mahalanobis distance should be > 0 for offset measurement"
        );
    }

    #[test]
    fn test_gating_distance_multiple_measurements() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        t.predict();
        t.update(&[0.5, 0.5, 1.0, 0.5]);
        t.predict();

        // Two measurements: one close, one far
        let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(2, 0.0);
        measurements
            .row_mut(0)
            .copy_from_slice(&[0.5, 0.5, 1.0, 0.5]); // close
        measurements
            .row_mut(1)
            .copy_from_slice(&[5.0, 5.0, 1.0, 0.5]); // far

        let dists = t.gating_distance(&measurements, false, GatingDistanceMetric::Mahalanobis);
        assert_eq!(dists.len(), 2, "Should return one distance per measurement");
        assert!(dists[0].is_finite());
        assert!(dists[1].is_finite());
        assert!(
            dists[1] > dists[0],
            "Far measurement should have larger distance: {} vs {}",
            dists[1],
            dists[0]
        );
    }

    #[test]
    fn test_initiate_mean_matches_measurement() {
        let measurement = [0.3, 0.7, 1.5, 2.0];
        let t = ConstantVelocityXYAHModel2::new(&measurement, 0.25);

        // Position portion of mean should match the measurement exactly
        let x: f32 = t.mean[0];
        let y: f32 = t.mean[1];
        let a: f32 = t.mean[2];
        let h: f32 = t.mean[3];
        assert!((x - 0.3).abs() < 1e-6, "Mean x should be 0.3, got {x}");
        assert!((y - 0.7).abs() < 1e-6, "Mean y should be 0.7, got {y}");
        assert!((a - 1.5).abs() < 1e-6, "Mean a should be 1.5, got {a}");
        assert!((h - 2.0).abs() < 1e-6, "Mean h should be 2.0, got {h}");

        // Velocity portion should be zero
        for i in 4..8 {
            let v: f32 = t.mean[i];
            assert!((v).abs() < 1e-6, "Velocity mean[{i}] should be 0, got {v}");
        }
    }
}