1use nalgebra::{
5 allocator::Allocator, convert, dimension::U4, DVector, DefaultAllocator, Dyn, OMatrix,
6 RealField, SVector, U1, U8,
7};
8
/// Kalman filter over an 8-dimensional state: four measured components
/// followed by one velocity per measured component.
///
/// The type name suggests the measured components are `[x, y, a, h]`; the code
/// itself treats component 3 as the height that scales all noise terms.
// NOTE(review): the `DefaultAllocator` bounds below are always satisfied for the
// fixed `U8` dimension and look redundant - confirm before removing.
#[derive(Debug, Clone)]
pub struct ConstantVelocityXYAHModel2<R>
where
    R: RealField,
    DefaultAllocator: Allocator<U8, U8>,
    DefaultAllocator: Allocator<U8>,
{
    /// State estimate: entries `0..4` mirror the measurement, `4..8` hold the
    /// corresponding velocities.
    pub mean: SVector<R, 8>,
    /// Multiplied by the height to obtain position-related noise standard deviations.
    pub std_weight_position: R,
    /// Multiplied by the height to obtain velocity-related noise standard deviations.
    pub std_weight_velocity: R,
    /// Scale applied to the innovation before it corrects the mean in `update`.
    pub update_factor: R,
    /// State transition matrix applied by `predict`.
    motion_matrix: OMatrix<R, U8, U8>,
    /// Maps the 8-dim state into the 4-dim measurement space (`project`/`update`).
    update_matrix: OMatrix<R, U4, U8>,
    /// State covariance matrix.
    pub covariance: OMatrix<R, U8, U8>,
}
24
/// Distance metric used when gating measurements against a track's projected
/// state.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GatingDistanceMetric {
    /// Sum of squared residuals (squared Euclidean distance); ignores the
    /// projected covariance.
    Gaussian,
    /// Residual weighted by the projected measurement covariance.
    Mahalanobis,
}
30
31impl<R> ConstantVelocityXYAHModel2<R>
32where
33 R: RealField + Copy,
34{
35 pub fn new(measurement: &[R; 4], update_factor: R) -> Self {
36 let ndim = 4;
37 let dt: R = convert(1.0);
38
39 let mut motion_matrix = OMatrix::<R, U8, U8>::identity();
40 for i in 0..ndim {
41 motion_matrix[(i, ndim + i)] = dt * convert(3.0);
42 }
43 let mut update_matrix = OMatrix::<R, U4, U8>::identity();
44 for i in 0..ndim {
45 update_matrix[(i, ndim + i)] = dt * convert(1.0);
46 }
47 let zero: R = convert(0.0);
48 let two: R = convert(2.0);
49 let ten: R = convert(10.0);
50 let height = measurement[3];
51
52 let mean = SVector::<R, 8>::from_row_slice(&[
53 measurement[0],
54 measurement[1],
55 measurement[2],
56 measurement[3],
57 zero,
58 zero,
59 zero,
60 zero,
61 ]);
62 let std_weight_position = convert(1.0 / 20.0);
63 let std_weight_velocity = convert(1.0 / 160.0);
64 let diag = [
65 two * std_weight_position * height,
66 two * std_weight_position * height,
67 convert(0.01),
68 two * std_weight_position * height,
69 ten * std_weight_velocity * height,
70 ten * std_weight_velocity * height,
71 convert(0.00001),
72 ten * std_weight_velocity * height,
73 ];
74 let diag = SVector::<R, 8>::from_row_slice(&diag);
75
76 let covariance = OMatrix::<R, U8, U8>::from_diagonal(&diag.component_mul(&diag));
77 Self {
78 motion_matrix,
79 update_matrix,
80 mean,
81 covariance,
82 std_weight_position,
83 std_weight_velocity,
84 update_factor,
85 }
86 }
87
88 pub fn predict(&mut self) {
89 let height = self.mean[3];
90 let diag = [
91 self.std_weight_position * height,
92 self.std_weight_position * height,
93 convert(0.01),
94 self.std_weight_position * height,
95 self.std_weight_velocity * height,
96 self.std_weight_velocity * height,
97 convert(0.00001),
98 self.std_weight_velocity * height,
99 ];
100 let diag = SVector::<R, 8>::from_row_slice(&diag);
101 let motion_cov = OMatrix::<R, U8, U8>::from_diagonal(&diag.component_mul(&diag));
102
103 let mean = (self.mean.transpose() * self.motion_matrix.transpose()).transpose();
104 let covariance =
105 self.motion_matrix * self.covariance * self.motion_matrix.transpose() + motion_cov;
106 self.mean = mean;
107 self.covariance = covariance;
108 }
109
110 pub fn project(&self) -> (OMatrix<R, U4, U1>, OMatrix<R, U4, U4>) {
111 let height = self.mean[3];
112 let diag = [
113 self.std_weight_position * height,
114 self.std_weight_position * height,
115 convert(0.01),
116 self.std_weight_position * height,
117 ];
118 let diag = SVector::<R, 4>::from_row_slice(&diag);
119 let innovation_cov = OMatrix::<R, U4, U4>::from_diagonal(&diag.component_mul(&diag));
120 let mean = self.update_matrix * self.mean;
121 let covariance =
122 self.update_matrix * self.covariance * self.update_matrix.transpose() + innovation_cov;
123 (mean, covariance)
124 }
125
126 pub fn update(&mut self, measurement: &[R; 4]) {
127 let measurement = SVector::<R, 4>::from_row_slice(&[
128 measurement[0],
129 measurement[1],
130 measurement[2],
131 measurement[3],
132 ]);
133
134 let (projected_mean, projected_cov) = self.project();
135 let cho_factor = match projected_cov.cholesky() {
136 None => return,
137 Some(v) => v,
138 };
139 let kalman_gain = cho_factor
140 .solve(&(self.covariance * self.update_matrix.transpose()).transpose())
141 .transpose();
142
143 let innovation = (measurement - projected_mean).scale(self.update_factor);
144 let diff = innovation.transpose() * kalman_gain.transpose();
147 self.mean += diff.transpose();
148 self.covariance -= kalman_gain * projected_cov * kalman_gain.transpose();
149 }
153
154 #[allow(dead_code)]
155 pub fn gating_distance(
156 &self,
157 measurements: &OMatrix<R, Dyn, U4>,
158 only_position: bool,
159 metric: GatingDistanceMetric,
160 ) -> DVector<R> {
161 let (m, cov) = self.project();
162 let ndims = if only_position { 2 } else { 4 };
163 let mean = m.transpose();
164 let mean = mean.columns_range(0..ndims);
165 let covariance = cov.view_range(0..ndims, 0..ndims);
166 let measurements = measurements.columns_range(0..ndims);
167 let mut mean_broadcast =
173 OMatrix::<R, Dyn, U4>::from_element(measurements.shape().0, convert(0.0));
174 for mut col in mean_broadcast.row_iter_mut() {
175 col.copy_from(&mean);
176 }
177 let d = measurements - mean_broadcast;
178 match metric {
179 GatingDistanceMetric::Gaussian => d.component_mul(&d).column_sum(),
180 GatingDistanceMetric::Mahalanobis => {
181 let cho_factor = match covariance.cholesky() {
182 None => return DVector::<R>::zeros(measurements.shape().0),
183 Some(v) => v,
184 };
185 let z = cho_factor.solve(&d.transpose());
186 z.component_mul(&z).row_sum_tr()
187 }
188 }
189 }
190}
191
#[cfg(test)]
mod tests {
    use nalgebra::{Dyn, OMatrix, U4};

    use super::{ConstantVelocityXYAHModel2, GatingDistanceMetric};

    /// Runs a predict/update sequence with a drifting `x` measurement and
    /// checks that the state stays finite.
    #[test]
    fn filter() {
        let mut t = ConstantVelocityXYAHModel2::<f64>::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        for (step, &x) in [0.4, 0.3, 0.2, 0.2, 0.3, 0.4].iter().enumerate() {
            t.predict();
            println!("{}. t.mean={}", step + 1, t.mean);
            t.update(&[x, 0.5, 1.0, 0.5]);
        }

        assert!(
            t.mean.iter().all(|v| v.is_finite()),
            "mean should stay finite, got {}",
            t.mean
        );
        assert!(
            t.covariance.iter().all(|v| v.is_finite()),
            "covariance should stay finite"
        );
    }

    /// Gates a measurement away from the tracked position with both metrics.
    #[test]
    fn gating() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        for x in [0.49, 0.48, 0.47, 0.46, 0.45, 0.44, 0.43, 0.42] {
            t.predict();
            t.update(&[x, 0.5, 1.0, 0.5]);
        }
        t.predict();

        let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(1, 0.0);
        measurements.copy_from_slice(&[0.3, 0.5, 1.0, 0.5]);

        let dist = t.gating_distance(&measurements, false, GatingDistanceMetric::Mahalanobis);
        println!("Dist(false, maha): {dist}");
        assert_eq!(dist.len(), 1, "Should return one distance per measurement");
        assert!(
            dist[0].is_finite() && dist[0] >= 0.0,
            "Mahalanobis distance should be finite and non-negative, got {}",
            dist[0]
        );

        let dist = t.gating_distance(&measurements, false, GatingDistanceMetric::Gaussian);
        println!("Dist(false, gaussian): {dist}");
        assert_eq!(dist.len(), 1, "Should return one distance per measurement");
        assert!(
            dist[0].is_finite() && dist[0] >= 0.0,
            "Gaussian distance should be finite and non-negative, got {}",
            dist[0]
        );
    }

    /// Repeated predictions from a static start keep x/y bounded and finite.
    #[test]
    fn test_predict_constant_velocity() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 0.1, 2.0], 0.25);
        t.predict();
        t.update(&[0.5, 0.5, 0.1, 2.0]);

        let x_before: f32 = t.mean[0];
        let y_before: f32 = t.mean[1];

        for _ in 0..5 {
            t.predict();
        }

        let x_after: f32 = t.mean[0];
        let y_after: f32 = t.mean[1];
        let h_after: f32 = t.mean[3];

        assert!(x_after.is_finite(), "x should be finite after predictions");
        assert!(y_after.is_finite(), "y should be finite after predictions");
        assert!(
            h_after.is_finite(),
            "height should be finite after predictions"
        );

        assert!(
            (x_after - x_before).abs() < 5.0,
            "x drift should be bounded, got delta={}",
            (x_after - x_before).abs()
        );
        assert!(
            (y_after - y_before).abs() < 5.0,
            "y drift should be bounded, got delta={}",
            (y_after - y_before).abs()
        );
    }

    /// A thousand predictions without updates must not overflow to inf/NaN.
    #[test]
    fn test_numerical_stability_1000_cycles() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);

        for _ in 0..1000 {
            t.predict();
        }

        for i in 0..8 {
            let val: f32 = t.mean[i];
            assert!(
                val.is_finite(),
                "mean[{i}] should be finite after 1000 predictions, got {val}",
            );
        }

        for r in 0..8 {
            for c in 0..8 {
                let val: f32 = t.covariance[(r, c)];
                assert!(
                    val.is_finite(),
                    "covariance[({r},{c})] should be finite after 1000 predictions, got {val}",
                );
            }
        }
    }

    /// An exact match gates near zero; a distant measurement gates further away.
    #[test]
    fn test_gating_distance_edge_cases() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        for _ in 0..3 {
            t.predict();
            t.update(&[0.5, 0.5, 1.0, 0.5]);
        }
        t.predict();

        let (projected_mean, _) = t.project();
        let mut meas_close = OMatrix::<f32, Dyn, U4>::from_element(1, 0.0);
        meas_close
            .row_mut(0)
            .copy_from_slice(projected_mean.as_slice());

        let dist_close = t.gating_distance(&meas_close, false, GatingDistanceMetric::Mahalanobis);
        assert!(
            dist_close[0].is_finite(),
            "Close-measurement distance should be finite"
        );
        assert!(
            dist_close[0] < 1.0,
            "Distance for exact-match measurement should be near 0, got {}",
            dist_close[0]
        );

        let mut meas_far = OMatrix::<f32, Dyn, U4>::from_element(1, 0.0);
        meas_far.copy_from_slice(&[10.0, 10.0, 5.0, 10.0]);

        let dist_far = t.gating_distance(&meas_far, false, GatingDistanceMetric::Mahalanobis);
        assert!(
            dist_far[0].is_finite(),
            "Far-measurement distance should be finite"
        );
        assert!(
            dist_far[0] > dist_close[0],
            "Far measurement should have larger distance than close one: {} vs {}",
            dist_far[0],
            dist_close[0]
        );
    }

    /// An update pulls the mean toward, but not past, the measurement.
    #[test]
    fn test_update_moves_mean_toward_measurement() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        t.predict();

        let x_before: f32 = t.mean[0];
        t.update(&[0.6, 0.5, 1.0, 0.5]);
        let x_after: f32 = t.mean[0];

        assert!(
            x_after > x_before,
            "Mean x should move toward the measurement (0.6), was {x_before}, now {x_after}"
        );
        assert!(
            x_after <= 0.6,
            "Mean x should not overshoot the measurement, got {x_after}"
        );
    }

    /// The initial covariance has strictly positive variances.
    #[test]
    fn test_covariance_positive_diagonal() {
        let t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);

        for i in 0..8 {
            let val: f32 = t.covariance[(i, i)];
            assert!(
                val > 0.0,
                "Covariance diagonal[{i}] should be positive, got {val}"
            );
        }
    }

    /// Prediction adds process noise, so position variance grows.
    #[test]
    fn test_predict_increases_uncertainty() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);

        let cov_before: f32 = t.covariance[(0, 0)];
        t.predict();
        let cov_after: f32 = t.covariance[(0, 0)];

        assert!(
            cov_after > cov_before,
            "Predict should increase position uncertainty: {cov_before} -> {cov_after}"
        );
    }

    /// A measurement update shrinks position variance.
    #[test]
    fn test_update_decreases_uncertainty() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        t.predict();

        let cov_before: f32 = t.covariance[(0, 0)];
        t.update(&[0.5, 0.5, 1.0, 0.5]);
        let cov_after: f32 = t.covariance[(0, 0)];

        assert!(
            cov_after < cov_before,
            "Update should decrease position uncertainty: {cov_before} -> {cov_after}"
        );
    }

    /// Both metrics report a positive distance for an offset measurement.
    #[test]
    fn test_gating_distance_gaussian_vs_mahalanobis() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        for _ in 0..3 {
            t.predict();
            t.update(&[0.5, 0.5, 1.0, 0.5]);
        }
        t.predict();

        let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(1, 0.0);
        measurements.copy_from_slice(&[0.6, 0.5, 1.0, 0.5]);

        let dist_gauss = t.gating_distance(&measurements, false, GatingDistanceMetric::Gaussian);
        let dist_maha = t.gating_distance(&measurements, false, GatingDistanceMetric::Mahalanobis);

        assert!(dist_gauss[0].is_finite());
        assert!(dist_maha[0].is_finite());

        assert!(
            dist_gauss[0] > 0.0,
            "Gaussian distance should be > 0 for offset measurement"
        );
        assert!(
            dist_maha[0] > 0.0,
            "Mahalanobis distance should be > 0 for offset measurement"
        );
    }

    /// One distance is returned per measurement row, ordered like the input.
    #[test]
    fn test_gating_distance_multiple_measurements() {
        let mut t = ConstantVelocityXYAHModel2::new(&[0.5, 0.5, 1.0, 0.5], 0.25);
        t.predict();
        t.update(&[0.5, 0.5, 1.0, 0.5]);
        t.predict();

        let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(2, 0.0);
        measurements
            .row_mut(0)
            .copy_from_slice(&[0.5, 0.5, 1.0, 0.5]);
        measurements
            .row_mut(1)
            .copy_from_slice(&[5.0, 5.0, 1.0, 0.5]);

        let dists = t.gating_distance(&measurements, false, GatingDistanceMetric::Mahalanobis);
        assert_eq!(dists.len(), 2, "Should return one distance per measurement");
        assert!(dists[0].is_finite());
        assert!(dists[1].is_finite());
        assert!(
            dists[1] > dists[0],
            "Far measurement should have larger distance: {} vs {}",
            dists[1],
            dists[0]
        );
    }

    /// The constructor copies the measurement and zeroes all velocities.
    #[test]
    fn test_initiate_mean_matches_measurement() {
        let measurement = [0.3, 0.7, 1.5, 2.0];
        let t = ConstantVelocityXYAHModel2::new(&measurement, 0.25);

        let x: f32 = t.mean[0];
        let y: f32 = t.mean[1];
        let a: f32 = t.mean[2];
        let h: f32 = t.mean[3];
        assert!((x - 0.3).abs() < 1e-6, "Mean x should be 0.3, got {x}");
        assert!((y - 0.7).abs() < 1e-6, "Mean y should be 0.7, got {y}");
        assert!((a - 1.5).abs() < 1e-6, "Mean a should be 1.5, got {a}");
        assert!((h - 2.0).abs() < 1e-6, "Mean h should be 2.0, got {h}");

        for i in 4..8 {
            let v: f32 = t.mean[i];
            assert!((v).abs() < 1e-6, "Velocity mean[{i}] should be 0, got {v}");
        }
    }
}