norfair_rs/distances/
dispatch.rs

1//! Enum-based distance dispatch for static (non-virtual) function calls.
2//!
3//! This module provides `DistanceFunction`, an enum that wraps all supported
4//! distance types and dispatches without vtable lookups, improving performance
5//! for hot-path code.
6
7#[cfg(feature = "python")]
8use std::sync::Arc;
9
10use super::functions::{frobenius, iou, mean_euclidean, mean_manhattan};
11use super::scalar::ScalarDistance;
12use super::scipy_wrapper::ScipyDistance;
13use super::traits::Distance;
14use super::vectorized::VectorizedDistance;
15use crate::{Detection, TrackedObject};
16use nalgebra::DMatrix;
17
/// Callback signature used by [`CustomDistance`].
///
/// The callback receives the tracked objects and the candidate detections and returns a
/// distance matrix. It lives behind an `Arc`, so cloning a `CustomDistance` shares the same
/// callback rather than copying it.
#[cfg(feature = "python")]
pub type CustomDistanceFn =
    Arc<dyn Fn(&[&TrackedObject], &[&Detection]) -> DMatrix<f64> + Send + Sync>;

/// Distance backed by an arbitrary user callback (e.g. a Python callable).
#[cfg(feature = "python")]
#[derive(Clone)]
pub struct CustomDistance {
    // Shared callback; cloning the struct only bumps the `Arc` reference count.
    func: CustomDistanceFn,
}

#[cfg(feature = "python")]
impl CustomDistance {
    /// Wraps `f` so it can be stored in a `DistanceFunction::Custom` variant.
    ///
    /// `f` must satisfy the bounds of [`CustomDistanceFn`] (`Send + Sync`, and `'static`
    /// because the trait object carries no borrowed lifetime).
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(&[&TrackedObject], &[&Detection]) -> DMatrix<f64> + Send + Sync + 'static,
    {
        let func: CustomDistanceFn = Arc::new(f);
        Self { func }
    }

    /// Invokes the callback on `objects` and `candidates` and returns its matrix as-is.
    #[inline]
    pub fn get_distances(
        &self,
        objects: &[&TrackedObject],
        candidates: &[&Detection],
    ) -> DMatrix<f64> {
        (*self.func)(objects, candidates)
    }

    /// ReID variant: distances between two sets of `TrackedObject`s.
    ///
    /// The callback only understands `Detection` candidates, so each candidate object is
    /// turned into a temporary `Detection`. Its `points` and `absolute_points` are copies of
    /// the candidate's `estimate`, its `label` and `age` are carried over, and every other
    /// optional field is `None`. The callback is therefore called with
    /// `(objects, detections)`. Callables written for `(TrackedObject, TrackedObject)` pairs
    /// never see the candidate objects themselves; that would need a dedicated
    /// ReID-specific callback type.
    #[inline]
    pub fn get_distances_objects(
        &self,
        objects: &[&TrackedObject],
        candidates: &[&TrackedObject],
    ) -> DMatrix<f64> {
        let mut proxies: Vec<Detection> = Vec::with_capacity(candidates.len());
        for candidate in candidates {
            proxies.push(Detection {
                points: candidate.estimate.clone(),
                scores: None,
                label: candidate.label.clone(),
                embedding: None,
                data: None,
                absolute_points: Some(candidate.estimate.clone()),
                age: Some(candidate.age),
            });
        }

        let proxy_refs: Vec<&Detection> = proxies.iter().collect();
        self.get_distances(objects, &proxy_refs)
    }
}
83
84/// Enum-based distance function for static dispatch.
85///
86/// This avoids `Box<dyn Distance>` vtable overhead by using an enum
87/// with inline implementations. Use `distance_function_by_name()` to
88/// create instances.
89#[derive(Clone)]
90pub enum DistanceFunction {
91    // Scalar distance functions
92    Frobenius(ScalarDistance),
93    MeanEuclidean(ScalarDistance),
94    MeanManhattan(ScalarDistance),
95
96    // Vectorized distance functions
97    Iou(VectorizedDistance),
98
99    // Scipy-style distance functions
100    ScipyEuclidean(ScipyDistance),
101    ScipySqeuclidean(ScipyDistance),
102    ScipyManhattan(ScipyDistance),
103    ScipyCosine(ScipyDistance),
104    ScipyChebyshev(ScipyDistance),
105
106    /// Custom distance function (used for Python callables).
107    /// Only available with the "python" feature.
108    #[cfg(feature = "python")]
109    Custom(CustomDistance),
110}
111
112impl DistanceFunction {
113    /// Get distances between objects and candidates.
114    #[inline(always)]
115    pub fn get_distances(
116        &self,
117        objects: &[&TrackedObject],
118        candidates: &[&Detection],
119    ) -> DMatrix<f64> {
120        match self {
121            // Scalar functions
122            DistanceFunction::Frobenius(d) => d.get_distances(objects, candidates),
123            DistanceFunction::MeanEuclidean(d) => d.get_distances(objects, candidates),
124            DistanceFunction::MeanManhattan(d) => d.get_distances(objects, candidates),
125
126            // Vectorized functions
127            DistanceFunction::Iou(d) => d.get_distances(objects, candidates),
128
129            // Scipy functions
130            DistanceFunction::ScipyEuclidean(d) => d.get_distances(objects, candidates),
131            DistanceFunction::ScipySqeuclidean(d) => d.get_distances(objects, candidates),
132            DistanceFunction::ScipyManhattan(d) => d.get_distances(objects, candidates),
133            DistanceFunction::ScipyCosine(d) => d.get_distances(objects, candidates),
134            DistanceFunction::ScipyChebyshev(d) => d.get_distances(objects, candidates),
135
136            // Custom distance function (Python callables)
137            #[cfg(feature = "python")]
138            DistanceFunction::Custom(d) => d.get_distances(objects, candidates),
139        }
140    }
141
142    /// Get distances between two sets of TrackedObjects (for ReID matching).
143    ///
144    /// For built-in distance functions, creates temporary Detections from candidate estimates.
145    /// For custom Python callables, this requires the reid_distance_function to accept
146    /// (TrackedObject, TrackedObject) -> float (not Detection, TrackedObject).
147    #[inline(always)]
148    pub fn get_distances_objects(
149        &self,
150        objects: &[&TrackedObject],
151        candidates: &[&TrackedObject],
152    ) -> DMatrix<f64> {
153        // For built-in functions, create temporary detections from candidate estimates
154        // and use the standard distance computation
155        let temp_detections: Vec<Detection> = candidates
156            .iter()
157            .map(|obj| Detection {
158                points: obj.estimate.clone(),
159                scores: None,
160                label: obj.label.clone(),
161                embedding: None,
162                data: None,
163                absolute_points: Some(obj.estimate.clone()),
164                age: Some(obj.age),
165            })
166            .collect();
167
168        let det_refs: Vec<&Detection> = temp_detections.iter().collect();
169
170        match self {
171            // For Custom (Python callback), we need special handling
172            // The callback expects (TrackedObject, TrackedObject), not (Detection, TrackedObject)
173            #[cfg(feature = "python")]
174            DistanceFunction::Custom(d) => d.get_distances_objects(objects, candidates),
175
176            // For all built-in functions, use the standard detection-based distance
177            _ => self.get_distances(objects, &det_refs),
178        }
179    }
180}
181
182/// Create a DistanceFunction enum by name (static dispatch version).
183///
184/// This is the preferred way to create distance functions for performance-critical code.
185///
186/// # Panics
187/// Panics if the distance name is not recognized.
188pub fn distance_function_by_name(name: &str) -> DistanceFunction {
189    match name {
190        // Scalar functions
191        "frobenius" => DistanceFunction::Frobenius(ScalarDistance::new(frobenius)),
192        "mean_euclidean" => DistanceFunction::MeanEuclidean(ScalarDistance::new(mean_euclidean)),
193        "mean_manhattan" => DistanceFunction::MeanManhattan(ScalarDistance::new(mean_manhattan)),
194
195        // Vectorized functions
196        "iou" => DistanceFunction::Iou(VectorizedDistance::new(iou)),
197
198        // Scipy functions
199        "euclidean" => DistanceFunction::ScipyEuclidean(ScipyDistance::new("euclidean")),
200        "sqeuclidean" => DistanceFunction::ScipySqeuclidean(ScipyDistance::new("sqeuclidean")),
201        "manhattan" | "cityblock" => {
202            DistanceFunction::ScipyManhattan(ScipyDistance::new("manhattan"))
203        }
204        "cosine" => DistanceFunction::ScipyCosine(ScipyDistance::new("cosine")),
205        "chebyshev" => DistanceFunction::ScipyChebyshev(ScipyDistance::new("chebyshev")),
206
207        _ => panic!("Unknown distance function: {}", name),
208    }
209}
210
211/// Create a DistanceFunction enum by name, returning a Result instead of panicking.
212///
213/// This is useful for error handling when the distance name comes from user input.
214pub fn try_distance_function_by_name(name: &str) -> Result<DistanceFunction, String> {
215    match name {
216        // Scalar functions
217        "frobenius" => Ok(DistanceFunction::Frobenius(ScalarDistance::new(frobenius))),
218        "mean_euclidean" => Ok(DistanceFunction::MeanEuclidean(ScalarDistance::new(mean_euclidean))),
219        "mean_manhattan" => Ok(DistanceFunction::MeanManhattan(ScalarDistance::new(mean_manhattan))),
220
221        // Vectorized functions
222        "iou" => Ok(DistanceFunction::Iou(VectorizedDistance::new(iou))),
223
224        // Scipy functions
225        "euclidean" => Ok(DistanceFunction::ScipyEuclidean(ScipyDistance::new("euclidean"))),
226        "sqeuclidean" => Ok(DistanceFunction::ScipySqeuclidean(ScipyDistance::new("sqeuclidean"))),
227        "manhattan" | "cityblock" => Ok(DistanceFunction::ScipyManhattan(ScipyDistance::new("manhattan"))),
228        "cosine" => Ok(DistanceFunction::ScipyCosine(ScipyDistance::new("cosine"))),
229        "chebyshev" => Ok(DistanceFunction::ScipyChebyshev(ScipyDistance::new("chebyshev"))),
230
231        _ => Err(format!("Unknown distance function: {}. Supported: frobenius, mean_euclidean, mean_manhattan, iou, euclidean, sqeuclidean, manhattan, cityblock, cosine, chebyshev", name)),
232    }
233}
234
235// Implement the Distance trait for DistanceFunction so it can be used interchangeably
236impl Distance for DistanceFunction {
237    #[inline(always)]
238    fn get_distances(&self, objects: &[&TrackedObject], candidates: &[&Detection]) -> DMatrix<f64> {
239        // Delegate to the inherent method
240        DistanceFunction::get_distances(self, objects, candidates)
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bare `Detection` whose `points` is a `rows x cols` matrix filled row-major.
    fn create_mock_detection(points: &[f64], rows: usize, cols: usize) -> Detection {
        Detection {
            points: DMatrix::from_row_slice(rows, cols, points),
            scores: None,
            label: None,
            embedding: None,
            data: None,
            absolute_points: None,
            age: None,
        }
    }

    /// Builds a minimal `TrackedObject` whose estimate is a `rows x cols` matrix filled
    /// row-major.
    fn create_mock_tracked_object(estimate: &[f64], rows: usize, cols: usize) -> TrackedObject {
        let estimate_matrix = DMatrix::from_row_slice(rows, cols, estimate);
        TrackedObject {
            id: Some(0),
            global_id: 0,
            initializing_id: None,
            age: 0,
            hit_counter: 1,
            point_hit_counter: vec![1; rows],
            last_detection: None,
            last_distance: None,
            current_min_distance: None,
            past_detections: std::collections::VecDeque::new(),
            label: None,
            reid_hit_counter: None,
            estimate: estimate_matrix.clone(),
            estimate_velocity: DMatrix::zeros(rows, cols),
            is_initializing: false,
            detected_at_least_once_points: vec![true; rows],
            filter: crate::filter::FilterEnum::None(crate::filter::NoFilter::new(&estimate_matrix)),
            initial_period: 1,
            num_points: rows,
            dim_points: cols,
            last_coord_transform: None,
        }
    }

    #[test]
    fn test_distance_function_frobenius() {
        let distance = distance_function_by_name("frobenius");
        let det = create_mock_detection(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let obj = create_mock_tracked_object(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let matrix = distance.get_distances(&[&obj], &[&det]);
        assert!((matrix[(0, 0)] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_distance_function_iou() {
        let distance = distance_function_by_name("iou");
        let det = create_mock_detection(&[0.0, 0.0, 1.0, 1.0], 1, 4);
        let obj = create_mock_tracked_object(&[0.0, 0.0, 1.0, 1.0], 1, 4);
        let matrix = distance.get_distances(&[&obj], &[&det]);
        assert!((matrix[(0, 0)] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_distance_function_euclidean() {
        let distance = distance_function_by_name("euclidean");
        let det = create_mock_detection(&[1.0, 2.0], 1, 2);
        let obj = create_mock_tracked_object(&[1.0, 2.0], 1, 2);
        let matrix = distance.get_distances(&[&obj], &[&det]);
        assert!((matrix[(0, 0)] - 0.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "Unknown distance function")]
    fn test_distance_function_invalid() {
        distance_function_by_name("invalid_distance");
    }

    // ===== Name lookup tests =====

    #[test]
    fn test_try_distance_function_by_name_accepts_all_supported_names() {
        let supported: &[&str] = &[
            "frobenius",
            "mean_euclidean",
            "mean_manhattan",
            "iou",
            "euclidean",
            "sqeuclidean",
            "manhattan",
            "cityblock",
            "cosine",
            "chebyshev",
        ];
        for name in supported {
            assert!(
                try_distance_function_by_name(name).is_ok(),
                "{} should be a supported distance",
                name
            );
        }
    }

    #[test]
    fn test_try_distance_function_by_name_rejects_unknown() {
        let err = try_distance_function_by_name("invalid_distance")
            .err()
            .expect("unknown names must be rejected");
        assert!(err.contains("Unknown distance function: invalid_distance"));
        assert!(err.contains("Supported:"));
    }

    #[test]
    fn test_cityblock_is_alias_for_manhattan() {
        assert!(matches!(
            distance_function_by_name("cityblock"),
            DistanceFunction::ScipyManhattan(_)
        ));
        assert!(matches!(
            try_distance_function_by_name("cityblock"),
            Ok(DistanceFunction::ScipyManhattan(_))
        ));
    }

    #[test]
    fn test_panicking_and_fallible_lookup_agree() {
        assert!(matches!(
            distance_function_by_name("frobenius"),
            DistanceFunction::Frobenius(_)
        ));
        assert!(matches!(
            try_distance_function_by_name("frobenius"),
            Ok(DistanceFunction::Frobenius(_))
        ));
        assert!(matches!(distance_function_by_name("iou"), DistanceFunction::Iou(_)));
        assert!(matches!(
            try_distance_function_by_name("iou"),
            Ok(DistanceFunction::Iou(_))
        ));
    }

    // ===== ReID (object-to-object) tests =====

    #[test]
    fn test_get_distances_objects_builtin_identical_estimates() {
        let distance = distance_function_by_name("frobenius");
        let obj = create_mock_tracked_object(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let cand = create_mock_tracked_object(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let matrix = distance.get_distances_objects(&[&obj], &[&cand]);
        assert!((matrix[(0, 0)] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_get_distances_objects_builtin_uses_candidate_estimate() {
        // Frobenius norm of [3, 4] - [0, 0] is 5.
        let distance = distance_function_by_name("frobenius");
        let obj = create_mock_tracked_object(&[0.0, 0.0], 1, 2);
        let cand = create_mock_tracked_object(&[3.0, 4.0], 1, 2);
        let matrix = distance.get_distances_objects(&[&obj], &[&cand]);
        assert!((matrix[(0, 0)] - 5.0).abs() < 1e-6);
    }

    // ===== CustomDistance Tests (Python feature only) =====

    #[cfg(feature = "python")]
    #[test]
    fn test_custom_distance_basic() {
        // Custom distance returning the euclidean distance between first points.
        let custom = CustomDistance::new(|objects, candidates| {
            DMatrix::from_fn(candidates.len(), objects.len(), |c, o| {
                let det_point = candidates[c].points.row(0);
                let obj_point = objects[o].estimate.row(0);
                det_point
                    .iter()
                    .zip(obj_point.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>()
                    .sqrt()
            })
        });

        let det = create_mock_detection(&[1.0, 2.0], 1, 2);
        let obj = create_mock_tracked_object(&[1.0, 2.0], 1, 2);

        let matrix = custom.get_distances(&[&obj], &[&det]);
        assert!(
            (matrix[(0, 0)] - 0.0).abs() < 1e-6,
            "Perfect match should have distance 0"
        );
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_custom_distance_nonzero() {
        // Custom distance that returns a fixed value
        let custom = CustomDistance::new(|objects, candidates| {
            DMatrix::from_element(candidates.len(), objects.len(), 42.0)
        });

        let det = create_mock_detection(&[1.0, 2.0], 1, 2);
        let obj = create_mock_tracked_object(&[100.0, 200.0], 1, 2);

        let matrix = custom.get_distances(&[&obj], &[&det]);
        assert!(
            (matrix[(0, 0)] - 42.0).abs() < 1e-6,
            "Should return fixed value 42"
        );
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_custom_distance_multiple_objects_and_detections() {
        // Custom distance that returns row + col index
        let custom = CustomDistance::new(|objects, candidates| {
            DMatrix::from_fn(candidates.len(), objects.len(), |c, o| (c + o) as f64)
        });

        let det1 = create_mock_detection(&[1.0, 1.0], 1, 2);
        let det2 = create_mock_detection(&[2.0, 2.0], 1, 2);
        let obj1 = create_mock_tracked_object(&[10.0, 10.0], 1, 2);
        let obj2 = create_mock_tracked_object(&[20.0, 20.0], 1, 2);

        let matrix = custom.get_distances(&[&obj1, &obj2], &[&det1, &det2]);

        // Matrix should be 2x2 (2 candidates x 2 objects)
        assert_eq!(matrix.nrows(), 2);
        assert_eq!(matrix.ncols(), 2);

        // Check values: matrix[(c, o)] = c + o
        assert!((matrix[(0, 0)] - 0.0).abs() < 1e-6); // c=0, o=0
        assert!((matrix[(0, 1)] - 1.0).abs() < 1e-6); // c=0, o=1
        assert!((matrix[(1, 0)] - 1.0).abs() < 1e-6); // c=1, o=0
        assert!((matrix[(1, 1)] - 2.0).abs() < 1e-6); // c=1, o=1
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_distance_function_custom_variant() {
        // Test DistanceFunction::Custom variant works through the enum dispatch
        let custom = CustomDistance::new(|_objects, _candidates| DMatrix::from_element(1, 1, 5.5));

        let distance = DistanceFunction::Custom(custom);

        let det = create_mock_detection(&[1.0, 2.0], 1, 2);
        let obj = create_mock_tracked_object(&[1.0, 2.0], 1, 2);

        let matrix = distance.get_distances(&[&obj], &[&det]);
        assert!(
            (matrix[(0, 0)] - 5.5).abs() < 1e-6,
            "Custom distance should return 5.5"
        );
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_custom_distance_clone() {
        // Test that CustomDistance can be cloned (via Arc)
        let custom = CustomDistance::new(|_objects, _candidates| DMatrix::from_element(1, 1, 7.0));

        let custom_clone = custom.clone();

        let det = create_mock_detection(&[1.0, 2.0], 1, 2);
        let obj = create_mock_tracked_object(&[1.0, 2.0], 1, 2);

        // Both should return the same value
        let matrix1 = custom.get_distances(&[&obj], &[&det]);
        let matrix2 = custom_clone.get_distances(&[&obj], &[&det]);

        assert!((matrix1[(0, 0)] - 7.0).abs() < 1e-6);
        assert!((matrix2[(0, 0)] - 7.0).abs() < 1e-6);
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_distance_function_custom_clone() {
        // Test that DistanceFunction::Custom can be cloned.
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant` rejects 3.14.
        let custom = CustomDistance::new(|_objects, _candidates| DMatrix::from_element(1, 1, 2.5));

        let distance = DistanceFunction::Custom(custom);
        let distance_clone = distance.clone();

        let det = create_mock_detection(&[1.0, 2.0], 1, 2);
        let obj = create_mock_tracked_object(&[1.0, 2.0], 1, 2);

        let matrix1 = distance.get_distances(&[&obj], &[&det]);
        let matrix2 = distance_clone.get_distances(&[&obj], &[&det]);

        assert!((matrix1[(0, 0)] - 2.5).abs() < 1e-6);
        assert!((matrix2[(0, 0)] - 2.5).abs() < 1e-6);
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_custom_get_distances_objects_receives_candidate_estimates() {
        // The callback reports the first coordinate of the first candidate it receives,
        // which must be the candidate object's estimate wrapped in a Detection.
        let custom = CustomDistance::new(|_objects, candidates| {
            DMatrix::from_element(1, 1, candidates[0].points[(0, 0)])
        });
        let distance = DistanceFunction::Custom(custom);

        let obj = create_mock_tracked_object(&[0.0, 0.0], 1, 2);
        let cand = create_mock_tracked_object(&[7.0, 8.0], 1, 2);

        let matrix = distance.get_distances_objects(&[&obj], &[&cand]);
        assert!((matrix[(0, 0)] - 7.0).abs() < 1e-6);
    }
}