1#[cfg(feature = "python")]
8use std::sync::Arc;
9
10use super::functions::{frobenius, iou, mean_euclidean, mean_manhattan};
11use super::scalar::ScalarDistance;
12use super::scipy_wrapper::ScipyDistance;
13use super::traits::Distance;
14use super::vectorized::VectorizedDistance;
15use crate::{Detection, TrackedObject};
16use nalgebra::DMatrix;
17
/// Callback type backing [`CustomDistance`].
///
/// It receives the tracked objects and the candidate detections and returns
/// their distance matrix. The `Send + Sync` bounds allow the callback to be
/// shared across threads. Storing it behind an `Arc` means cloning a
/// `CustomDistance` shares the same callback instead of copying it.
#[cfg(feature = "python")]
pub type CustomDistanceFn =
    Arc<dyn Fn(&[&TrackedObject], &[&Detection]) -> DMatrix<f64> + Send + Sync>;

/// User-supplied distance function, available only with the `python` feature.
///
/// Cheap to clone: clones share one `Arc`-wrapped callback.
#[cfg(feature = "python")]
#[derive(Clone)]
pub struct CustomDistance {
    // The callback invoked by `get_distances` / `get_distances_objects`.
    func: CustomDistanceFn,
}
31
#[cfg(feature = "python")]
impl CustomDistance {
    /// Wraps `f` as a custom distance.
    ///
    /// `f` is called with the tracked objects and the candidate detections and
    /// must return their distance matrix. It is stored behind an `Arc`, so clones
    /// of the returned value share the same callback.
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(&[&TrackedObject], &[&Detection]) -> DMatrix<f64> + Send + Sync + 'static,
    {
        let func: CustomDistanceFn = Arc::new(f);
        Self { func }
    }

    /// Invokes the wrapped callback and returns its matrix unchanged.
    #[inline]
    pub fn get_distances(
        &self,
        objects: &[&TrackedObject],
        candidates: &[&Detection],
    ) -> DMatrix<f64> {
        let callback = &self.func;
        callback(objects, candidates)
    }

    /// Computes distances between tracked objects and *other tracked objects*.
    ///
    /// Each candidate is first turned into a temporary [`Detection`]:
    /// - its current estimate is used as both `points` and `absolute_points`;
    /// - its `label` and `age` are carried over;
    /// - `scores`, `embedding` and `data` are left empty.
    ///
    /// The callback is then invoked exactly as in [`CustomDistance::get_distances`].
    #[inline]
    pub fn get_distances_objects(
        &self,
        objects: &[&TrackedObject],
        candidates: &[&TrackedObject],
    ) -> DMatrix<f64> {
        let mut as_detections: Vec<Detection> = Vec::with_capacity(candidates.len());
        for candidate in candidates {
            as_detections.push(Detection {
                points: candidate.estimate.clone(),
                scores: None,
                label: candidate.label.clone(),
                embedding: None,
                data: None,
                absolute_points: Some(candidate.estimate.clone()),
                age: Some(candidate.age),
            });
        }

        let borrowed: Vec<&Detection> = as_detections.iter().collect();
        self.get_distances(objects, &borrowed)
    }
}
83
/// Distance metric used to compare tracked objects against candidates.
///
/// Each variant wraps the concrete implementation that produces the distance
/// matrix. [`distance_function_by_name`] / [`try_distance_function_by_name`]
/// map string names to variants.
#[derive(Clone)]
pub enum DistanceFunction {
    /// `"frobenius"`: a [`ScalarDistance`] built from `functions::frobenius`.
    Frobenius(ScalarDistance),
    /// `"mean_euclidean"`: a [`ScalarDistance`] built from `functions::mean_euclidean`.
    MeanEuclidean(ScalarDistance),
    /// `"mean_manhattan"`: a [`ScalarDistance`] built from `functions::mean_manhattan`.
    MeanManhattan(ScalarDistance),

    /// `"iou"`: a [`VectorizedDistance`] built from `functions::iou`.
    Iou(VectorizedDistance),

    /// `"euclidean"`: [`ScipyDistance`] with metric name `"euclidean"`.
    ScipyEuclidean(ScipyDistance),
    /// `"sqeuclidean"`: [`ScipyDistance`] with metric name `"sqeuclidean"`.
    ScipySqeuclidean(ScipyDistance),
    /// `"manhattan"` or its alias `"cityblock"`: [`ScipyDistance`] with metric name `"manhattan"`.
    ScipyManhattan(ScipyDistance),
    /// `"cosine"`: [`ScipyDistance`] with metric name `"cosine"`.
    ScipyCosine(ScipyDistance),
    /// `"chebyshev"`: [`ScipyDistance`] with metric name `"chebyshev"`.
    ScipyChebyshev(ScipyDistance),

    /// User-supplied callback; not reachable by name, only constructed directly.
    /// Compiled only with the `python` feature.
    #[cfg(feature = "python")]
    Custom(CustomDistance),
}
111
112impl DistanceFunction {
113 #[inline(always)]
115 pub fn get_distances(
116 &self,
117 objects: &[&TrackedObject],
118 candidates: &[&Detection],
119 ) -> DMatrix<f64> {
120 match self {
121 DistanceFunction::Frobenius(d) => d.get_distances(objects, candidates),
123 DistanceFunction::MeanEuclidean(d) => d.get_distances(objects, candidates),
124 DistanceFunction::MeanManhattan(d) => d.get_distances(objects, candidates),
125
126 DistanceFunction::Iou(d) => d.get_distances(objects, candidates),
128
129 DistanceFunction::ScipyEuclidean(d) => d.get_distances(objects, candidates),
131 DistanceFunction::ScipySqeuclidean(d) => d.get_distances(objects, candidates),
132 DistanceFunction::ScipyManhattan(d) => d.get_distances(objects, candidates),
133 DistanceFunction::ScipyCosine(d) => d.get_distances(objects, candidates),
134 DistanceFunction::ScipyChebyshev(d) => d.get_distances(objects, candidates),
135
136 #[cfg(feature = "python")]
138 DistanceFunction::Custom(d) => d.get_distances(objects, candidates),
139 }
140 }
141
142 #[inline(always)]
148 pub fn get_distances_objects(
149 &self,
150 objects: &[&TrackedObject],
151 candidates: &[&TrackedObject],
152 ) -> DMatrix<f64> {
153 let temp_detections: Vec<Detection> = candidates
156 .iter()
157 .map(|obj| Detection {
158 points: obj.estimate.clone(),
159 scores: None,
160 label: obj.label.clone(),
161 embedding: None,
162 data: None,
163 absolute_points: Some(obj.estimate.clone()),
164 age: Some(obj.age),
165 })
166 .collect();
167
168 let det_refs: Vec<&Detection> = temp_detections.iter().collect();
169
170 match self {
171 #[cfg(feature = "python")]
174 DistanceFunction::Custom(d) => d.get_distances_objects(objects, candidates),
175
176 _ => self.get_distances(objects, &det_refs),
178 }
179 }
180}
181
182pub fn distance_function_by_name(name: &str) -> DistanceFunction {
189 match name {
190 "frobenius" => DistanceFunction::Frobenius(ScalarDistance::new(frobenius)),
192 "mean_euclidean" => DistanceFunction::MeanEuclidean(ScalarDistance::new(mean_euclidean)),
193 "mean_manhattan" => DistanceFunction::MeanManhattan(ScalarDistance::new(mean_manhattan)),
194
195 "iou" => DistanceFunction::Iou(VectorizedDistance::new(iou)),
197
198 "euclidean" => DistanceFunction::ScipyEuclidean(ScipyDistance::new("euclidean")),
200 "sqeuclidean" => DistanceFunction::ScipySqeuclidean(ScipyDistance::new("sqeuclidean")),
201 "manhattan" | "cityblock" => {
202 DistanceFunction::ScipyManhattan(ScipyDistance::new("manhattan"))
203 }
204 "cosine" => DistanceFunction::ScipyCosine(ScipyDistance::new("cosine")),
205 "chebyshev" => DistanceFunction::ScipyChebyshev(ScipyDistance::new("chebyshev")),
206
207 _ => panic!("Unknown distance function: {}", name),
208 }
209}
210
211pub fn try_distance_function_by_name(name: &str) -> Result<DistanceFunction, String> {
215 match name {
216 "frobenius" => Ok(DistanceFunction::Frobenius(ScalarDistance::new(frobenius))),
218 "mean_euclidean" => Ok(DistanceFunction::MeanEuclidean(ScalarDistance::new(mean_euclidean))),
219 "mean_manhattan" => Ok(DistanceFunction::MeanManhattan(ScalarDistance::new(mean_manhattan))),
220
221 "iou" => Ok(DistanceFunction::Iou(VectorizedDistance::new(iou))),
223
224 "euclidean" => Ok(DistanceFunction::ScipyEuclidean(ScipyDistance::new("euclidean"))),
226 "sqeuclidean" => Ok(DistanceFunction::ScipySqeuclidean(ScipyDistance::new("sqeuclidean"))),
227 "manhattan" | "cityblock" => Ok(DistanceFunction::ScipyManhattan(ScipyDistance::new("manhattan"))),
228 "cosine" => Ok(DistanceFunction::ScipyCosine(ScipyDistance::new("cosine"))),
229 "chebyshev" => Ok(DistanceFunction::ScipyChebyshev(ScipyDistance::new("chebyshev"))),
230
231 _ => Err(format!("Unknown distance function: {}. Supported: frobenius, mean_euclidean, mean_manhattan, iou, euclidean, sqeuclidean, manhattan, cityblock, cosine, chebyshev", name)),
232 }
233}
234
/// Exposes [`DistanceFunction`] through the generic [`Distance`] trait.
impl Distance for DistanceFunction {
    #[inline(always)]
    fn get_distances(&self, objects: &[&TrackedObject], candidates: &[&Detection]) -> DMatrix<f64> {
        // This path resolves to the *inherent* `DistanceFunction::get_distances`,
        // because inherent associated items take precedence over trait items.
        // The trait method therefore forwards instead of recursing into itself.
        DistanceFunction::get_distances(self, objects, candidates)
    }
}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Detection` from row-major `points` with every optional field unset.
    fn create_mock_detection(points: &[f64], rows: usize, cols: usize) -> Detection {
        Detection {
            points: DMatrix::from_row_slice(rows, cols, points),
            scores: None,
            label: None,
            embedding: None,
            data: None,
            absolute_points: None,
            age: None,
        }
    }

    /// Builds a `TrackedObject` whose estimate is `estimate` (row-major).
    ///
    /// All other state is a minimal placeholder, using a `NoFilter` filter.
    fn create_mock_tracked_object(estimate: &[f64], rows: usize, cols: usize) -> TrackedObject {
        let estimate_matrix = DMatrix::from_row_slice(rows, cols, estimate);
        TrackedObject {
            id: Some(0),
            global_id: 0,
            initializing_id: None,
            age: 0,
            hit_counter: 1,
            point_hit_counter: vec![1; rows],
            last_detection: None,
            last_distance: None,
            current_min_distance: None,
            past_detections: std::collections::VecDeque::new(),
            label: None,
            reid_hit_counter: None,
            estimate: estimate_matrix.clone(),
            estimate_velocity: DMatrix::zeros(rows, cols),
            is_initializing: false,
            detected_at_least_once_points: vec![true; rows],
            filter: crate::filter::FilterEnum::None(crate::filter::NoFilter::new(&estimate_matrix)),
            initial_period: 1,
            num_points: rows,
            dim_points: cols,
            last_coord_transform: None,
        }
    }

    #[test]
    fn test_distance_function_frobenius() {
        let distance = distance_function_by_name("frobenius");
        let det = create_mock_detection(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let obj = create_mock_tracked_object(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let matrix = distance.get_distances(&[&obj], &[&det]);
        assert!((matrix[(0, 0)] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_distance_function_iou() {
        let distance = distance_function_by_name("iou");
        let det = create_mock_detection(&[0.0, 0.0, 1.0, 1.0], 1, 4);
        let obj = create_mock_tracked_object(&[0.0, 0.0, 1.0, 1.0], 1, 4);
        let matrix = distance.get_distances(&[&obj], &[&det]);
        assert!((matrix[(0, 0)] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_distance_function_euclidean() {
        let distance = distance_function_by_name("euclidean");
        let det = create_mock_detection(&[1.0, 2.0], 1, 2);
        let obj = create_mock_tracked_object(&[1.0, 2.0], 1, 2);
        let matrix = distance.get_distances(&[&obj], &[&det]);
        assert!((matrix[(0, 0)] - 0.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "Unknown distance function")]
    fn test_distance_function_invalid() {
        distance_function_by_name("invalid_distance");
    }

    #[test]
    fn test_try_distance_function_by_name_accepts_known_names() {
        for name in [
            "frobenius",
            "mean_euclidean",
            "mean_manhattan",
            "iou",
            "euclidean",
            "sqeuclidean",
            "manhattan",
            "cityblock",
            "cosine",
            "chebyshev",
        ] {
            assert!(try_distance_function_by_name(name).is_ok(), "{} should be known", name);
        }
    }

    #[test]
    fn test_try_distance_function_by_name_rejects_unknown() {
        // `.err()` avoids needing `Debug` on `DistanceFunction`.
        let message = try_distance_function_by_name("invalid_distance")
            .err()
            .expect("unknown names must be rejected");
        assert!(message.contains("Unknown distance function: invalid_distance"));
    }

    #[test]
    fn test_distance_function_objects_builtin() {
        // Comparing an object with itself through the object->detection path
        // must behave like a perfect match.
        let distance = distance_function_by_name("frobenius");
        let obj = create_mock_tracked_object(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let matrix = distance.get_distances_objects(&[&obj], &[&obj]);
        assert!((matrix[(0, 0)] - 0.0).abs() < 1e-6);
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_custom_distance_basic() {
        let custom = CustomDistance::new(|objects, candidates| {
            let n_cands = candidates.len();
            let n_objs = objects.len();
            let mut matrix = DMatrix::zeros(n_cands, n_objs);

            for (c, cand) in candidates.iter().enumerate() {
                for (o, obj) in objects.iter().enumerate() {
                    let det_point = cand.points.row(0);
                    let obj_point = obj.estimate.row(0);
                    let diff: f64 = det_point
                        .iter()
                        .zip(obj_point.iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum();
                    matrix[(c, o)] = diff.sqrt();
                }
            }
            matrix
        });

        let det = create_mock_detection(&[1.0, 2.0], 1, 2);
        let obj = create_mock_tracked_object(&[1.0, 2.0], 1, 2);

        let matrix = custom.get_distances(&[&obj], &[&det]);
        assert!(
            (matrix[(0, 0)] - 0.0).abs() < 1e-6,
            "Perfect match should have distance 0"
        );
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_custom_distance_nonzero() {
        let custom = CustomDistance::new(|objects, candidates| {
            let n_cands = candidates.len();
            let n_objs = objects.len();
            let mut matrix = DMatrix::zeros(n_cands, n_objs);
            for c in 0..n_cands {
                for o in 0..n_objs {
                    matrix[(c, o)] = 42.0;
                }
            }
            matrix
        });

        let det = create_mock_detection(&[1.0, 2.0], 1, 2);
        let obj = create_mock_tracked_object(&[100.0, 200.0], 1, 2);

        let matrix = custom.get_distances(&[&obj], &[&det]);
        assert!(
            (matrix[(0, 0)] - 42.0).abs() < 1e-6,
            "Should return fixed value 42"
        );
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_custom_distance_multiple_objects_and_detections() {
        let custom = CustomDistance::new(|objects, candidates| {
            let n_cands = candidates.len();
            let n_objs = objects.len();
            let mut matrix = DMatrix::zeros(n_cands, n_objs);
            for c in 0..n_cands {
                for o in 0..n_objs {
                    matrix[(c, o)] = (c + o) as f64;
                }
            }
            matrix
        });

        let det1 = create_mock_detection(&[1.0, 1.0], 1, 2);
        let det2 = create_mock_detection(&[2.0, 2.0], 1, 2);
        let obj1 = create_mock_tracked_object(&[10.0, 10.0], 1, 2);
        let obj2 = create_mock_tracked_object(&[20.0, 20.0], 1, 2);

        let matrix = custom.get_distances(&[&obj1, &obj2], &[&det1, &det2]);

        assert_eq!(matrix.nrows(), 2);
        assert_eq!(matrix.ncols(), 2);

        assert!((matrix[(0, 0)] - 0.0).abs() < 1e-6);
        assert!((matrix[(0, 1)] - 1.0).abs() < 1e-6);
        assert!((matrix[(1, 0)] - 1.0).abs() < 1e-6);
        assert!((matrix[(1, 1)] - 2.0).abs() < 1e-6);
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_distance_function_custom_variant() {
        let custom = CustomDistance::new(|_objects, _candidates| DMatrix::from_element(1, 1, 5.5));

        let distance = DistanceFunction::Custom(custom);

        let det = create_mock_detection(&[1.0, 2.0], 1, 2);
        let obj = create_mock_tracked_object(&[1.0, 2.0], 1, 2);

        let matrix = distance.get_distances(&[&obj], &[&det]);
        assert!(
            (matrix[(0, 0)] - 5.5).abs() < 1e-6,
            "Custom distance should return 5.5"
        );
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_distance_function_custom_objects_conversion() {
        // The candidate object must reach the callback as a detection whose points
        // are the candidate's estimate, with absolute_points populated.
        let custom = CustomDistance::new(|_objects, candidates| {
            assert!(candidates[0].absolute_points.is_some());
            DMatrix::from_element(1, 1, candidates[0].points[(0, 1)])
        });
        let distance = DistanceFunction::Custom(custom);

        let obj = create_mock_tracked_object(&[1.0, 2.0], 1, 2);
        let other = create_mock_tracked_object(&[3.0, 4.0], 1, 2);

        let matrix = distance.get_distances_objects(&[&obj], &[&other]);
        assert!((matrix[(0, 0)] - 4.0).abs() < 1e-6);
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_custom_distance_clone() {
        let custom = CustomDistance::new(|_objects, _candidates| DMatrix::from_element(1, 1, 7.0));

        let custom_clone = custom.clone();

        let det = create_mock_detection(&[1.0, 2.0], 1, 2);
        let obj = create_mock_tracked_object(&[1.0, 2.0], 1, 2);

        let matrix1 = custom.get_distances(&[&obj], &[&det]);
        let matrix2 = custom_clone.get_distances(&[&obj], &[&det]);

        assert!((matrix1[(0, 0)] - 7.0).abs() < 1e-6);
        assert!((matrix2[(0, 0)] - 7.0).abs() < 1e-6);
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_distance_function_custom_clone() {
        // 2.5 rather than an approximation of PI: clippy's deny-by-default
        // `approx_constant` lint rejects literals such as 3.14.
        let custom = CustomDistance::new(|_objects, _candidates| DMatrix::from_element(1, 1, 2.5));

        let distance = DistanceFunction::Custom(custom);
        let distance_clone = distance.clone();

        let det = create_mock_detection(&[1.0, 2.0], 1, 2);
        let obj = create_mock_tracked_object(&[1.0, 2.0], 1, 2);

        let matrix1 = distance.get_distances(&[&obj], &[&det]);
        let matrix2 = distance_clone.get_distances(&[&obj], &[&det]);

        assert!((matrix1[(0, 0)] - 2.5).abs() < 1e-6);
        assert!((matrix2[(0, 0)] - 2.5).abs() < 1e-6);
    }
}