norfair_rs/
tracker.rs

1//! Main tracker implementation.
2
3use nalgebra::{DMatrix, DVector};
4use std::collections::VecDeque;
5
6use crate::camera_motion::CoordinateTransformation;
7use crate::distances::{distance_function_by_name, DistanceFunction};
8use crate::filter::FilterFactoryEnum;
9use crate::internal::numpy::to_row_major_vec;
10use crate::matching::{get_unmatched, match_detections_and_objects};
11use crate::tracked_object::get_next_global_id;
12use crate::{Detection, Error, Result, TrackedObject};
13
14/// Configuration for the tracker.
15#[derive(Clone)]
16pub struct TrackerConfig {
17    /// Distance function for matching detections to objects (enum-based static dispatch).
18    pub distance_function: DistanceFunction,
19
20    /// Maximum distance threshold for valid matches.
21    pub distance_threshold: f64,
22
23    /// Maximum hit counter value (frames to keep object alive without detections).
24    pub hit_counter_max: i32,
25
26    /// Frames before an object becomes "initialized" (gets permanent ID).
27    pub initialization_delay: i32,
28
29    /// Maximum hit counter for individual points.
30    pub pointwise_hit_counter_max: i32,
31
32    /// Minimum score for a detection point to be considered.
33    pub detection_threshold: f64,
34
35    /// Factory for creating Kalman filters (enum-based static dispatch).
36    pub filter_factory: FilterFactoryEnum,
37
38    /// Number of past detections to store for re-identification.
39    pub past_detections_length: usize,
40
41    /// Optional distance function for re-identification.
42    pub reid_distance_function: Option<DistanceFunction>,
43
44    /// Distance threshold for re-identification.
45    pub reid_distance_threshold: f64,
46
47    /// Maximum hit counter for re-identification phase.
48    pub reid_hit_counter_max: Option<i32>,
49}
50
51impl TrackerConfig {
52    /// Create a new tracker configuration with enum-based dispatch.
53    ///
54    /// # Arguments
55    /// * `distance_function` - Distance function for matching
56    /// * `distance_threshold` - Maximum match distance
57    pub fn new(distance_function: DistanceFunction, distance_threshold: f64) -> Self {
58        Self {
59            distance_function,
60            distance_threshold,
61            hit_counter_max: 15,
62            initialization_delay: -1, // Will be set to hit_counter_max / 2
63            pointwise_hit_counter_max: 4,
64            detection_threshold: 0.0,
65            filter_factory: FilterFactoryEnum::default(),
66            past_detections_length: 4,
67            reid_distance_function: None,
68            reid_distance_threshold: 1.0,
69            reid_hit_counter_max: None,
70        }
71    }
72
73    /// Create configuration from a distance function name.
74    pub fn from_distance_name(name: &str, distance_threshold: f64) -> Self {
75        Self::new(distance_function_by_name(name), distance_threshold)
76    }
77}
78
79/// Object tracker.
80///
81/// Maintains a set of tracked objects across frames, matching new detections
82/// to existing objects and managing object lifecycles.
83pub struct Tracker {
84    /// Tracker configuration.
85    pub config: TrackerConfig,
86
87    /// Currently tracked objects.
88    pub tracked_objects: Vec<TrackedObject>,
89
90    /// Local instance ID counter.
91    instance_id_counter: i32,
92
93    /// Local initializing ID counter.
94    initializing_id_counter: i32,
95}
96
97impl Tracker {
98    /// Create a new tracker with the given configuration.
99    pub fn new(mut config: TrackerConfig) -> Result<Self> {
100        // Validate and set defaults
101        if config.initialization_delay == -1 {
102            config.initialization_delay = config.hit_counter_max / 2;
103        }
104
105        if config.initialization_delay < 0 {
106            return Err(Error::InvalidConfig(
107                "initialization_delay must be non-negative".to_string(),
108            ));
109        }
110
111        if config.initialization_delay >= config.hit_counter_max {
112            return Err(Error::InvalidConfig(
113                "initialization_delay must be less than hit_counter_max".to_string(),
114            ));
115        }
116
117        Ok(Self {
118            config,
119            tracked_objects: Vec::new(),
120            // Start at 1 to match Python/Go behavior (IDs are 1-indexed)
121            instance_id_counter: 1,
122            initializing_id_counter: 1,
123        })
124    }
125
126    /// Update the tracker with new detections.
127    ///
128    /// # Arguments
129    /// * `detections` - New detections for this frame
130    /// * `period` - Frame period (for hit counter increment)
131    /// * `coord_transform` - Optional coordinate transformation for camera motion
132    ///
133    /// # Returns
134    /// Slice of active (non-initializing) tracked objects
135    pub fn update(
136        &mut self,
137        mut detections: Vec<Detection>,
138        period: i32,
139        coord_transform: Option<&dyn CoordinateTransformation>,
140    ) -> Vec<&TrackedObject> {
141        // Apply coordinate transformation to detections
142        if let Some(transform) = coord_transform {
143            for det in &mut detections {
144                let abs_points = transform.rel_to_abs(&det.points);
145                det.set_absolute_points(abs_points);
146            }
147        }
148
149        // STAGE 2: Remove dead objects BEFORE predict step (matches Python/Go behavior)
150        // Also categorize objects BEFORE decrement (Python categorizes before tracker_step)
151        // With ReID: objects survive while reid_hit_counter >= 0, separate into alive/dead
152        // Without ReID: objects with hit_counter < 0 are removed
153        let dead_indices: Vec<usize> = if self.config.reid_hit_counter_max.is_none() {
154            // No ReID: remove dead objects (hit_counter < 0)
155            self.tracked_objects
156                .retain(|obj| obj.hit_counter_is_positive());
157            vec![] // No dead objects to track
158        } else {
159            // With ReID: keep objects with reid_hit_counter >= 0
160            self.tracked_objects
161                .retain(|obj| obj.reid_hit_counter_is_positive());
162            // Collect indices of dead objects (hit_counter < 0 but reid_hit_counter >= 0)
163            self.tracked_objects
164                .iter()
165                .enumerate()
166                .filter(|(_, obj)| !obj.hit_counter_is_positive())
167                .map(|(i, _)| i)
168                .collect()
169        };
170
171        // IMPORTANT: Categorize objects BEFORE predict step (Python does this before tracker_step)
172        // This means objects with hit_counter=0 are still considered "alive" for matching this frame
173        // - alive_initialized: hit_counter >= 0, not initializing (participate in regular matching)
174        // - initializing: initializing objects (participate in init matching)
175        // - dead: hit_counter < 0 (only participate in ReID matching via dead_indices computed above)
176        let alive_initialized_indices: Vec<usize> = self
177            .tracked_objects
178            .iter()
179            .enumerate()
180            .filter(|(_, obj)| !obj.is_initializing && obj.hit_counter_is_positive())
181            .map(|(i, _)| i)
182            .collect();
183
184        let initializing_indices: Vec<usize> = self
185            .tracked_objects
186            .iter()
187            .enumerate()
188            .filter(|(_, obj)| obj.is_initializing)
189            .map(|(i, _)| i)
190            .collect();
191
192        // STAGE 3: Age all tracked objects (predict step) - AFTER categorization
193        for obj in &mut self.tracked_objects {
194            // ReID counter management (BEFORE hit_counter decrement - matches Python)
195            if obj.reid_hit_counter.is_none() {
196                if obj.hit_counter <= 0 {
197                    // Transition to ReID phase
198                    obj.reid_hit_counter = self.config.reid_hit_counter_max;
199                }
200            } else {
201                // Already in ReID phase, decrement
202                obj.reid_hit_counter = obj.reid_hit_counter.map(|c| c - 1);
203            }
204
205            obj.age += 1;
206            // Decrement hit_counter for ALL objects (matches Python/Go behavior)
207            // Matched objects will get +2*period in hit_object(), unmatched decay by 1
208            obj.hit_counter -= 1;
209
210            // Decrement point hit counters
211            for counter in &mut obj.point_hit_counter {
212                *counter = (*counter - 1).max(0);
213            }
214
215            // Kalman predict
216            obj.filter.predict();
217
218            // Update estimate from filter
219            obj.estimate = obj.filter.get_state();
220
221            // Update velocity estimate
222            let state = obj.filter.get_state_vector();
223            let dim_z = obj.filter.dim_z();
224            if state.len() >= dim_z * 2 {
225                let velocity_flat: Vec<f64> =
226                    state.iter().skip(dim_z).take(dim_z).cloned().collect();
227                obj.estimate_velocity =
228                    DMatrix::from_vec(obj.num_points, obj.dim_points, velocity_flat);
229            }
230
231            // Store coordinate transform for later use
232            if let Some(transform) = coord_transform {
233                obj.last_coord_transform = Some(transform.clone_box());
234            }
235        }
236
237        // Match alive initialized objects first (dead objects only participate in ReID)
238        let det_refs: Vec<&Detection> = detections.iter().collect();
239        let alive_init_obj_refs: Vec<&TrackedObject> = alive_initialized_indices
240            .iter()
241            .map(|&i| &self.tracked_objects[i])
242            .collect();
243
244        let distance_matrix = if !alive_init_obj_refs.is_empty() && !det_refs.is_empty() {
245            self.config
246                .distance_function
247                .get_distances(&alive_init_obj_refs, &det_refs)
248        } else {
249            DMatrix::zeros(det_refs.len(), alive_init_obj_refs.len())
250        };
251
252        let (matched_dets, matched_objs) =
253            match_detections_and_objects(&distance_matrix, self.config.distance_threshold);
254
255        // Update matched initialized objects
256        for (&det_idx, &obj_local_idx) in matched_dets.iter().zip(matched_objs.iter()) {
257            let obj_idx = alive_initialized_indices[obj_local_idx];
258            self.hit_object(
259                obj_idx,
260                &detections[det_idx],
261                period,
262                distance_matrix[(det_idx, obj_local_idx)],
263            );
264        }
265
266        // Get unmatched alive initialized objects (for ReID)
267        let unmatched_alive_init_indices: Vec<usize> =
268            get_unmatched(alive_initialized_indices.len(), &matched_objs)
269                .into_iter()
270                .map(|i| alive_initialized_indices[i])
271                .collect();
272
273        // Get unmatched detections
274        let unmatched_det_indices = get_unmatched(detections.len(), &matched_dets);
275
276        // Match initializing objects with unmatched detections
277        let unmatched_det_refs: Vec<&Detection> = unmatched_det_indices
278            .iter()
279            .map(|&i| &detections[i])
280            .collect();
281        let init_obj_refs: Vec<&TrackedObject> = initializing_indices
282            .iter()
283            .map(|&i| &self.tracked_objects[i])
284            .collect();
285
286        let init_distance_matrix = if !init_obj_refs.is_empty() && !unmatched_det_refs.is_empty() {
287            self.config
288                .distance_function
289                .get_distances(&init_obj_refs, &unmatched_det_refs)
290        } else {
291            DMatrix::zeros(unmatched_det_refs.len(), init_obj_refs.len())
292        };
293
294        let (init_matched_dets, init_matched_objs) =
295            match_detections_and_objects(&init_distance_matrix, self.config.distance_threshold);
296
297        // Track matched initializing objects (for ReID)
298        let matched_init_obj_indices: Vec<usize> = init_matched_objs
299            .iter()
300            .map(|&i| initializing_indices[i])
301            .collect();
302
303        // Update matched initializing objects
304        for (&local_det_idx, &obj_local_idx) in
305            init_matched_dets.iter().zip(init_matched_objs.iter())
306        {
307            let det_idx = unmatched_det_indices[local_det_idx];
308            let obj_idx = initializing_indices[obj_local_idx];
309            self.hit_object(
310                obj_idx,
311                &detections[det_idx],
312                period,
313                init_distance_matrix[(local_det_idx, obj_local_idx)],
314            );
315        }
316
317        // STAGE: ReID Matching (if enabled)
318        // Match old objects (unmatched alive initialized + dead) with initializing objects that got matched
319        if let Some(ref reid_distance) = self.config.reid_distance_function {
320            // Collect objects eligible for ReID: unmatched alive initialized + dead
321            let reid_object_indices: Vec<usize> = unmatched_alive_init_indices
322                .iter()
323                .chain(dead_indices.iter())
324                .cloned()
325                .collect();
326
327            // Only process if we have both candidates and objects to match
328            if !reid_object_indices.is_empty() && !matched_init_obj_indices.is_empty() {
329                // Build references for distance computation
330                let reid_obj_refs: Vec<&TrackedObject> = reid_object_indices
331                    .iter()
332                    .map(|&i| &self.tracked_objects[i])
333                    .collect();
334                let candidate_refs: Vec<&TrackedObject> = matched_init_obj_indices
335                    .iter()
336                    .map(|&i| &self.tracked_objects[i])
337                    .collect();
338
339                // Compute distance matrix using TrackedObject estimates as "detections"
340                // Note: reid_distance_function operates on TrackedObjects, using their estimates
341                let reid_distance_matrix =
342                    reid_distance.get_distances_objects(&reid_obj_refs, &candidate_refs);
343
344                // Match using same algorithm as detections
345                let (reid_matched_cands, reid_matched_objs) = match_detections_and_objects(
346                    &reid_distance_matrix,
347                    self.config.reid_distance_threshold,
348                );
349
350                // Process matches: merge old object with new, mark new for removal
351                let mut to_remove: Vec<usize> = vec![];
352                for (&cand_local, &obj_local) in
353                    reid_matched_cands.iter().zip(reid_matched_objs.iter())
354                {
355                    let old_obj_idx = reid_object_indices[obj_local];
356                    let new_obj_idx = matched_init_obj_indices[cand_local];
357
358                    // Get data from new object (need to clone due to borrow rules)
359                    let new_obj_data = self.tracked_objects[new_obj_idx].clone();
360
361                    // Merge: old object takes state from new object
362                    self.tracked_objects[old_obj_idx]
363                        .merge(&new_obj_data, self.config.past_detections_length);
364
365                    to_remove.push(new_obj_idx);
366                }
367
368                // Remove merged new objects (in reverse order to preserve indices)
369                to_remove.sort_unstable();
370                for idx in to_remove.into_iter().rev() {
371                    self.tracked_objects.remove(idx);
372                }
373            }
374        }
375
376        // Create new objects for remaining unmatched detections
377        let still_unmatched: Vec<_> =
378            get_unmatched(unmatched_det_indices.len(), &init_matched_dets)
379                .into_iter()
380                .map(|i| unmatched_det_indices[i])
381                .collect();
382
383        for det_idx in still_unmatched {
384            self.create_object(&detections[det_idx], period, coord_transform);
385        }
386
387        // Return active (non-initializing, non-negative hit_counter) objects
388        // NOTE: Use >= 0 to match Python norfair behavior (objects with hit_counter=0 are still active)
389        self.tracked_objects
390            .iter()
391            .filter(|obj| !obj.is_initializing && obj.hit_counter >= 0)
392            .collect()
393    }
394
395    /// Get the total number of objects that have been assigned permanent IDs.
396    /// Counter starts at 1, so we subtract 1 to get the count.
397    pub fn total_object_count(&self) -> i32 {
398        self.instance_id_counter - 1
399    }
400
401    /// Get the current number of active (non-initializing) objects.
402    pub fn current_object_count(&self) -> usize {
403        self.tracked_objects
404            .iter()
405            .filter(|obj| !obj.is_initializing && obj.hit_counter >= 0)
406            .count()
407    }
408
409    // Internal: update object with matched detection
410    fn hit_object(&mut self, obj_idx: usize, detection: &Detection, period: i32, distance: f64) {
411        // First, build observation matrix while we only need immutable access
412        let h = {
413            let obj = &self.tracked_objects[obj_idx];
414            self.build_observation_matrix_impl(obj, detection)
415        };
416
417        // Now get mutable access for updates
418        let obj = &mut self.tracked_objects[obj_idx];
419
420        // Update hit counter: add 2*period on match (matches Python/Go behavior)
421        // Combined with -1 in tracker_step, matched objects gain net +(2*period - 1)
422        obj.hit_counter = (obj.hit_counter + 2 * period).min(self.config.hit_counter_max);
423
424        // Check for initialization transition
425        // Note: use > not >= to match Python/Go behavior
426        if obj.is_initializing && obj.hit_counter > self.config.initialization_delay {
427            obj.is_initializing = false;
428            obj.id = Some(self.instance_id_counter);
429            self.instance_id_counter += 1;
430            // NOTE: Keep initializing_id - it's a permanent identifier, not just for initialization phase
431
432            // Reset reid_hit_counter if configured
433            if self.config.reid_hit_counter_max.is_some() {
434                obj.reid_hit_counter = None;
435            }
436        }
437
438        // Update point hit counters and detected_at_least_once_points
439        for (i, counter) in obj.point_hit_counter.iter_mut().enumerate() {
440            let score = detection.scores.as_ref().map(|s| s[i]).unwrap_or(1.0);
441            if score > self.config.detection_threshold {
442                *counter = (*counter + period).min(self.config.pointwise_hit_counter_max);
443                // Mark point as detected at least once
444                if i < obj.detected_at_least_once_points.len() {
445                    obj.detected_at_least_once_points[i] = true;
446                }
447            }
448        }
449
450        // Kalman update
451        // IMPORTANT: Use row-major flattening for measurement vector (matches Python/Go)
452        let measurement = DVector::from_vec(to_row_major_vec(detection.get_absolute_points()));
453        obj.filter.update(&measurement, None, h.as_ref());
454
455        // Update estimate
456        obj.estimate = obj.filter.get_state();
457
458        // Store detection
459        obj.last_detection = Some(detection.clone());
460        obj.last_distance = Some(distance);
461
462        // Update past detections
463        if self.config.past_detections_length > 0 {
464            obj.past_detections.push_back(detection.clone());
465            while obj.past_detections.len() > self.config.past_detections_length {
466                obj.past_detections.pop_front();
467            }
468        }
469    }
470
471    // Internal: create new tracked object
472    fn create_object(
473        &mut self,
474        detection: &Detection,
475        period: i32,
476        coord_transform: Option<&dyn CoordinateTransformation>,
477    ) {
478        let global_id = get_next_global_id();
479        let initializing_id = self.initializing_id_counter;
480        self.initializing_id_counter += 1;
481
482        let num_points = detection.num_points();
483        let dim_points = detection.num_dims();
484
485        // Create filter (use enum-based factory for static dispatch)
486        let filter = self
487            .config
488            .filter_factory
489            .create(detection.get_absolute_points());
490
491        // Initialize point hit counters
492        let point_hit_counter = vec![period.min(self.config.pointwise_hit_counter_max); num_points];
493
494        // Initialize detected_at_least_once_points based on detection scores
495        let detected_at_least_once_points = if let Some(ref scores) = detection.scores {
496            scores
497                .iter()
498                .map(|&s| s > self.config.detection_threshold)
499                .collect()
500        } else {
501            vec![true; num_points]
502        };
503
504        let mut obj = TrackedObject {
505            id: None,
506            global_id,
507            initializing_id: Some(initializing_id),
508            age: 0,
509            hit_counter: period,
510            point_hit_counter,
511            last_detection: Some(detection.clone()),
512            last_distance: None,
513            current_min_distance: None,
514            past_detections: VecDeque::new(),
515            label: detection.label.clone(),
516            // reid_hit_counter starts as None; only set to reid_hit_counter_max when
517            // transitioning to ReID phase (hit_counter <= 0) - matches Python behavior
518            reid_hit_counter: None,
519            estimate: filter.get_state(),
520            estimate_velocity: DMatrix::zeros(num_points, dim_points),
521            is_initializing: true,
522            detected_at_least_once_points,
523            filter,
524            initial_period: period,
525            num_points,
526            dim_points,
527            last_coord_transform: coord_transform.map(|t| t.clone_box()),
528        };
529
530        // Check for immediate initialization (delay = 0)
531        if self.config.initialization_delay == 0 {
532            obj.is_initializing = false;
533            obj.id = Some(self.instance_id_counter);
534            self.instance_id_counter += 1;
535            // NOTE: Keep initializing_id - it's a permanent identifier, not just for initialization phase
536        }
537
538        self.tracked_objects.push(obj);
539    }
540
541    // Internal: build observation matrix for partial observations
542    fn build_observation_matrix_impl(
543        &self,
544        obj: &TrackedObject,
545        detection: &Detection,
546    ) -> Option<DMatrix<f64>> {
547        let dim_z = obj.filter.dim_z();
548        let dim_x = obj.filter.dim_x();
549
550        // Check if any points should be masked
551        let scores = detection.scores.as_ref();
552        let needs_mask = scores
553            .map(|s| {
554                s.iter()
555                    .any(|&score| score <= self.config.detection_threshold)
556            })
557            .unwrap_or(false);
558
559        if !needs_mask {
560            return None;
561        }
562
563        // Build H matrix with zeros for masked points
564        let mut h = DMatrix::zeros(dim_z, dim_x);
565        for i in 0..dim_z {
566            let point_idx = i / obj.dim_points;
567            let score = scores.map(|s| s[point_idx]).unwrap_or(1.0);
568            if score > self.config.detection_threshold {
569                h[(i, i)] = 1.0;
570            }
571        }
572
573        Some(h)
574    }
575}
576
577#[cfg(test)]
578mod tests {
579    use super::*;
580    use crate::camera_motion::TranslationTransformation;
581
582    // ===== Basic Tracker Tests =====
583
584    /// Ported from Go: TestTracker_NewTracker
585    #[test]
586    fn test_tracker_new() {
587        let config = TrackerConfig::from_distance_name("euclidean", 100.0);
588        let tracker = Tracker::new(config).unwrap();
589
590        assert_eq!(tracker.tracked_objects.len(), 0);
591        assert_eq!(tracker.total_object_count(), 0);
592        assert_eq!(tracker.current_object_count(), 0);
593    }
594
595    /// Ported from Go: TestTracker_NewTracker (extended)
596    #[test]
597    fn test_tracker_new_with_defaults() {
598        // Test basic tracker creation
599        let mut config = TrackerConfig::from_distance_name("euclidean", 100.0);
600        config.hit_counter_max = 15;
601        config.initialization_delay = -1; // use default: 15/2 = 7
602        config.pointwise_hit_counter_max = 4;
603        config.detection_threshold = 0.0;
604        config.past_detections_length = 4;
605
606        let tracker = Tracker::new(config).unwrap();
607
608        // Verify configuration
609        assert_eq!(tracker.config.distance_threshold, 100.0);
610        assert_eq!(tracker.config.hit_counter_max, 15);
611        assert_eq!(tracker.config.initialization_delay, 7); // 15/2
612
613        // Verify initial state
614        assert_eq!(tracker.tracked_objects.len(), 0);
615        assert_eq!(tracker.current_object_count(), 0);
616        assert_eq!(tracker.total_object_count(), 0);
617    }
618
619    /// Ported from Go: TestTracker_InvalidInitializationDelay
620    #[test]
621    fn test_tracker_invalid_config() {
622        // Test that negative initialization_delay is rejected (note: -1 is sentinel for "use default")
623        let mut config = TrackerConfig::from_distance_name("euclidean", 100.0);
624        config.hit_counter_max = 15;
625        config.initialization_delay = -2; // invalid negative value (not sentinel -1)
626
627        assert!(
628            Tracker::new(config).is_err(),
629            "Expected error for negative initialization_delay"
630        );
631    }
632
633    /// Ported from Go: TestTracker_InvalidInitializationDelay (second case)
634    #[test]
635    fn test_tracker_invalid_config_delay_too_high() {
636        // Test that initialization_delay >= hit_counter_max is rejected
637        let mut config = TrackerConfig::from_distance_name("euclidean", 100.0);
638        config.hit_counter_max = 15;
639        config.initialization_delay = 15; // equal to hit_counter_max (invalid)
640
641        assert!(
642            Tracker::new(config).is_err(),
643            "Expected error for initialization_delay >= hit_counter_max"
644        );
645    }
646
647    /// Ported from Go: TestTracker_SimpleUpdate
648    #[test]
649    fn test_tracker_simple_update() {
650        // Create tracker
651        let mut config = TrackerConfig::from_distance_name("euclidean", 100.0);
652        config.hit_counter_max = 5;
653        config.initialization_delay = -1; // use default: 5/2 = 2
654
655        let mut tracker = Tracker::new(config).unwrap();
656
657        // Create a detection
658        let det = Detection::from_slice(&[10.0, 20.0], 1, 2).unwrap();
659
660        // Update with detection
661        let active = tracker.update(vec![det], 1, None);
662
663        // Should have 0 active objects (still initializing)
664        assert_eq!(active.len(), 0, "Expected 0 active objects (initializing)");
665
666        // Should have 1 tracked object total
667        assert_eq!(
668            tracker.tracked_objects.len(),
669            1,
670            "Expected 1 tracked object"
671        );
672
673        // Total count should be 0 (object hasn't gotten permanent ID yet)
674        assert_eq!(
675            tracker.total_object_count(),
676            0,
677            "Expected total count 0 (still initializing)"
678        );
679
680        // Object should be initializing
681        assert!(
682            tracker.tracked_objects[0].is_initializing,
683            "Expected object to be initializing"
684        );
685
686        // Object should have initializing ID but not permanent ID
687        assert!(
688            tracker.tracked_objects[0].initializing_id.is_some(),
689            "Expected initializing ID to be set"
690        );
691        assert!(
692            tracker.tracked_objects[0].id.is_none(),
693            "Expected permanent ID to be nil (still initializing)"
694        );
695    }
696
697    /// Ported from Go: TestTracker_UpdateEmptyDetections
698    #[test]
699    fn test_tracker_update_empty_detections() {
700        // Create tracker
701        let mut config = TrackerConfig::from_distance_name("euclidean", 100.0);
702        config.hit_counter_max = 5;
703        config.initialization_delay = -1; // use default
704
705        let mut tracker = Tracker::new(config).unwrap();
706
707        // Update with no detections (empty vec)
708        let active = tracker.update(vec![], 1, None);
709
710        assert_eq!(active.len(), 0, "Expected 0 active objects");
711
712        // Update again with empty vec
713        let active = tracker.update(Vec::new(), 1, None);
714
715        assert_eq!(active.len(), 0, "Expected 0 active objects");
716    }
717
718    #[test]
719    fn test_tracker_initialization() {
720        let mut config = TrackerConfig::from_distance_name("euclidean", 100.0);
721        config.hit_counter_max = 5;
722        config.initialization_delay = 2;
723
724        let mut tracker = Tracker::new(config).unwrap();
725
726        // First update - initializing (hit_counter = 1 on creation)
727        let det = Detection::from_slice(&[10.0, 20.0], 1, 2).unwrap();
728        let active = tracker.update(vec![det.clone()], 1, None);
729        assert_eq!(active.len(), 0);
730
731        // Second update - still initializing
732        // All objects decay: 1 -> 0, then match: +2 = 2, but 2 > 2 is false
733        let active = tracker.update(vec![det.clone()], 1, None);
734        assert_eq!(active.len(), 0);
735
736        // Third update - should be initialized now (hit_counter > initialization_delay)
737        // 2 -> 1, then match: +2 = 3, and 3 > 2 is true
738        let active = tracker.update(vec![det], 1, None);
739        assert_eq!(active.len(), 1);
740        assert!(active[0].id.is_some());
741    }
742
743    // ===== Detection Tests =====
744
745    /// Ported from Go: TestDetection_Creation
746    #[test]
747    fn test_detection_creation_2d() {
748        // Test valid 2D points
749        let det = Detection::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2).unwrap();
750
751        // Verify points shape
752        assert_eq!(det.points.nrows(), 3, "Expected 3 rows");
753        assert_eq!(det.points.ncols(), 2, "Expected 2 cols");
754    }
755
756    /// Ported from Go: TestDetection_Creation (3D case)
757    #[test]
758    fn test_detection_creation_3d() {
759        // Test valid 3D points
760        let det = Detection::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
761
762        // Verify points shape
763        assert_eq!(det.points.nrows(), 2, "Expected 2 rows");
764        assert_eq!(det.points.ncols(), 3, "Expected 3 cols");
765    }
766
767    // ===== TrackedObject Tests =====
768
769    /// Ported from Go: TestTrackedObject_Creation
770    #[test]
771    fn test_tracked_object_creation_via_tracker() {
772        // Create tracker with initialization_delay > 0
773        let mut config = TrackerConfig::from_distance_name("euclidean", 100.0);
774        config.hit_counter_max = 15;
775        config.initialization_delay = 7;
776
777        let mut tracker = Tracker::new(config).unwrap();
778
779        // Create detection with 2 points
780        let det = Detection::from_slice(&[10.0, 20.0, 30.0, 40.0], 2, 2).unwrap();
781
782        // Update tracker to create object
783        tracker.update(vec![det], 1, None);
784
785        // Verify object was created
786        assert_eq!(tracker.tracked_objects.len(), 1);
787        let obj = &tracker.tracked_objects[0];
788
789        // Verify initialization
790        assert_eq!(obj.num_points, 2, "Expected 2 points");
791        assert_eq!(obj.dim_points, 2, "Expected 2D points");
792        assert_eq!(obj.hit_counter, 1, "Expected hit counter 1");
793        assert!(obj.is_initializing, "Expected object to be initializing");
794        assert!(
795            obj.initializing_id.is_some(),
796            "Expected initializing ID to be set"
797        );
798        assert!(
799            obj.id.is_none(),
800            "Expected permanent ID to be nil (still initializing)"
801        );
802    }
803
804    // ===== Camera Motion Tests =====
805
806    /// Ported from Go: TestTracker_CameraMotion
807    #[test]
808    fn test_tracker_camera_motion() {
809        // Create tracker with euclidean distance, threshold=1, initialization_delay=0
810        let mut config = TrackerConfig::from_distance_name("euclidean", 1.0);
811        config.hit_counter_max = 1;
812        config.initialization_delay = 0; // no initialization delay
813
814        let mut tracker = Tracker::new(config).unwrap();
815
816        // Setup: movement_vector = [1, 1]
817        // So abs_to_rel adds (1,1) and rel_to_abs subtracts (1,1)
818        // If relative_points = [2, 2], then absolute_points = rel_to_abs([2,2]) = [1, 1]
819        let coord_transform = TranslationTransformation::new([1.0, 1.0]);
820
821        // Create detection with relative points [2, 2]
822        let det = Detection::from_slice(&[2.0, 2.0], 1, 2).unwrap();
823
824        // Update tracker with coordinate transformation
825        let active = tracker.update(vec![det], 1, Some(&coord_transform));
826
827        // Should have 1 active object (initialization_delay = 0)
828        assert_eq!(active.len(), 1, "Expected 1 active object");
829
830        let obj = active[0];
831
832        // Verify estimate (should be in absolute coordinates in internal state,
833        // but estimate is kept in relative coordinates by default)
834        // The filter is initialized with absolute points, so estimate reflects that
835        // We need to verify the transformation was applied correctly
836
837        // Note: The Rust implementation keeps estimate in the coordinate system
838        // used for filter initialization (absolute when transform provided)
839        // This is different from Go which transforms back to relative
840
841        // Just verify the object was created and has the right shape
842        assert_eq!(obj.num_points, 1);
843        assert_eq!(obj.dim_points, 2);
844    }
845
846    /// Test immediate initialization (delay = 0)
847    #[test]
848    fn test_tracker_immediate_initialization() {
849        let mut config = TrackerConfig::from_distance_name("euclidean", 100.0);
850        config.hit_counter_max = 5;
851        config.initialization_delay = 0;
852
853        let mut tracker = Tracker::new(config).unwrap();
854
855        // First detection should immediately get a permanent ID
856        let det = Detection::from_slice(&[10.0, 20.0], 1, 2).unwrap();
857        let active = tracker.update(vec![det], 1, None);
858
859        // Should have 1 active object immediately
860        assert_eq!(active.len(), 1, "Expected 1 active object with delay=0");
861        assert!(active[0].id.is_some(), "Expected permanent ID with delay=0");
862        assert!(
863            !active[0].is_initializing,
864            "Should not be initializing with delay=0"
865        );
866
867        // Total count should be 1
868        assert_eq!(tracker.total_object_count(), 1);
869    }
870
871    /// Test object count methods
872    #[test]
873    fn test_tracker_object_counts() {
874        let mut config = TrackerConfig::from_distance_name("euclidean", 100.0);
875        config.hit_counter_max = 5;
876        config.initialization_delay = 0; // immediate initialization
877
878        let mut tracker = Tracker::new(config).unwrap();
879
880        // Initially both counts should be 0
881        assert_eq!(tracker.total_object_count(), 0);
882        assert_eq!(tracker.current_object_count(), 0);
883
884        // Add first object
885        let det1 = Detection::from_slice(&[10.0, 20.0], 1, 2).unwrap();
886        tracker.update(vec![det1], 1, None);
887
888        assert_eq!(tracker.total_object_count(), 1);
889        assert_eq!(tracker.current_object_count(), 1);
890
891        // Add second object (far enough to not match first)
892        let det2 = Detection::from_slice(&[1000.0, 2000.0], 1, 2).unwrap();
893        tracker.update(vec![det2], 1, None);
894
895        assert_eq!(tracker.total_object_count(), 2);
896        // First object may have died (hit_counter decayed), check current count
897        // Since we're not matching, objects decay
898    }
899
900    // ===== Python Tracker Tests (ported from test_tracker.py) =====
901
902    /// Ported from Python: test_params (bad distance name)
903    #[test]
904    #[should_panic(expected = "Unknown distance function")]
905    fn test_tracker_params_bad_distance() {
906        let config = TrackerConfig::from_distance_name("_bad_distance", 10.0);
907        // This should panic when creating the tracker because distance function lookup fails
908        Tracker::new(config).unwrap();
909    }
910
911    /// Ported from Python: test_simple (hit counter dynamics)
912    /// Tests delay initialization and hit counter capping
913    #[test]
914    fn test_tracker_simple_hit_counter_dynamics() {
915        let delay = 1;
916        let counter_max = delay + 2; // = 3
917
918        let mut config = TrackerConfig::from_distance_name("euclidean", 100.0);
919        config.hit_counter_max = counter_max;
920        config.initialization_delay = delay;
921
922        let mut tracker = Tracker::new(config).unwrap();
923
924        let det = Detection::from_slice(&[1.0, 1.0], 1, 2).unwrap();
925
926        // Test the delay phase (object is initializing)
927        for _age in 0..delay {
928            let active = tracker.update(vec![det.clone()], 1, None);
929            assert_eq!(active.len(), 0, "Expected 0 active objects during delay");
930        }
931
932        // After delay, object becomes active and should have hit_counter = delay+1
933        let active = tracker.update(vec![det.clone()], 1, None);
934        assert_eq!(active.len(), 1, "Expected 1 active object after delay");
935
936        // Continue updating to see hit_counter cap at counter_max
937        for _ in 0..5 {
938            let active = tracker.update(vec![det.clone()], 1, None);
939            assert_eq!(active.len(), 1);
940            assert!(
941                active[0].hit_counter <= counter_max,
942                "Hit counter should be capped at {}, got {}",
943                counter_max,
944                active[0].hit_counter
945            );
946        }
947
948        // Now update without detections - hit_counter should decrease
949        let mut prev_counter = counter_max;
950        for _ in 0..counter_max {
951            let active = tracker.update(vec![], 1, None);
952            if active.len() == 1 {
953                assert!(
954                    active[0].hit_counter < prev_counter,
955                    "Hit counter should decrease without detections"
956                );
957                prev_counter = active[0].hit_counter;
958            }
959        }
960
961        // Object should disappear when hit_counter reaches 0
962        let active = tracker.update(vec![], 1, None);
963        assert_eq!(
964            active.len(),
965            0,
966            "Object should disappear when hit_counter reaches 0"
967        );
968    }
969
970    /// Ported from Python: test_moving
971    /// Test a moving object and verify velocity estimation
972    #[test]
973    fn test_tracker_moving_object() {
974        let mut config = TrackerConfig::from_distance_name("euclidean", 100.0);
975        config.hit_counter_max = 5;
976        config.initialization_delay = 0; // Use immediate initialization
977
978        let mut tracker = Tracker::new(config).unwrap();
979
980        // Update with moving detections along y-axis
981        // y: 1 -> 2 -> 3 -> 4
982        let active = tracker.update(
983            vec![Detection::from_slice(&[1.0, 1.0], 1, 2).unwrap()],
984            1,
985            None,
986        );
987        assert_eq!(
988            active.len(),
989            1,
990            "First detection should create active object"
991        );
992
993        tracker.update(
994            vec![Detection::from_slice(&[1.0, 2.0], 1, 2).unwrap()],
995            1,
996            None,
997        );
998        tracker.update(
999            vec![Detection::from_slice(&[1.0, 3.0], 1, 2).unwrap()],
1000            1,
1001            None,
1002        );
1003        let active = tracker.update(
1004            vec![Detection::from_slice(&[1.0, 4.0], 1, 2).unwrap()],
1005            1,
1006            None,
1007        );
1008
1009        assert_eq!(active.len(), 1, "Expected 1 active object");
1010
1011        // Check that estimated position makes sense
1012        // x should be close to 1, y should be between 3 and 4 (filter smoothing)
1013        let estimate = &active[0].estimate;
1014        assert!(
1015            (estimate[(0, 0)] - 1.0).abs() < 0.5,
1016            "X should be close to 1.0, got {}",
1017            estimate[(0, 0)]
1018        );
1019        assert!(
1020            estimate[(0, 1)] > 3.0 && estimate[(0, 1)] <= 4.5,
1021            "Y should be between 3 and 4.5, got {}",
1022            estimate[(0, 1)]
1023        );
1024    }
1025
1026    /// Ported from Python: test_distance_t
1027    /// Test distance threshold filtering - objects too far shouldn't match
1028    #[test]
1029    fn test_tracker_distance_threshold() {
1030        let mut config = TrackerConfig::from_distance_name("euclidean", 0.5); // small threshold
1031        config.hit_counter_max = 5;
1032        config.initialization_delay = 0; // immediate initialization
1033
1034        let mut tracker = Tracker::new(config).unwrap();
1035
1036        // First detection creates an object at (1.0, 1.0)
1037        let active = tracker.update(
1038            vec![Detection::from_slice(&[1.0, 1.0], 1, 2).unwrap()],
1039            1,
1040            None,
1041        );
1042        assert_eq!(active.len(), 1, "First detection should create object");
1043
1044        // Second detection at (1.0, 2.0) is distance 1.0 away, which > threshold 0.5
1045        // So it should create a NEW object, not match the existing one
1046        // Each update without matching causes existing objects to decay
1047        let active = tracker.update(
1048            vec![Detection::from_slice(&[1.0, 2.0], 1, 2).unwrap()],
1049            1,
1050            None,
1051        );
1052        // We should have 2 objects now (first one decaying, second new)
1053        assert!(active.len() >= 1, "Should have at least 1 object");
1054
1055        // A closer point (0.3 away) should match
1056        let active = tracker.update(
1057            vec![Detection::from_slice(&[1.0, 2.3], 1, 2).unwrap()],
1058            1,
1059            None,
1060        );
1061        assert!(
1062            active.len() >= 1,
1063            "Expected match when distance < threshold"
1064        );
1065    }
1066
1067    /// Ported from Python: test_1d_points
1068    /// Test that 1D point arrays are correctly handled
1069    #[test]
1070    fn test_tracker_1d_points() {
1071        let mut config = TrackerConfig::from_distance_name("euclidean", 100.0);
1072        config.hit_counter_max = 5;
1073        config.initialization_delay = 0;
1074
1075        let mut tracker = Tracker::new(config).unwrap();
1076
1077        // Create detection with 1D points [x, y] which should be treated as [[x, y]]
1078        let det = Detection::from_slice(&[1.0, 1.0], 1, 2).unwrap();
1079
1080        // Detection should have shape (1, 2)
1081        assert_eq!(det.points.nrows(), 1);
1082        assert_eq!(det.points.ncols(), 2);
1083
1084        let active = tracker.update(vec![det], 1, None);
1085        assert_eq!(active.len(), 1, "Expected 1 active object");
1086
1087        // Tracked object estimate should also have shape (1, 2)
1088        assert_eq!(active[0].estimate.nrows(), 1);
1089        assert_eq!(active[0].estimate.ncols(), 2);
1090    }
1091
1092    /// Ported from Python: test_count (comprehensive)
1093    /// Test total_object_count and current_object_count methods
1094    #[test]
1095    fn test_tracker_count_comprehensive() {
1096        let delay = 1;
1097        let counter_max = delay + 2; // = 3
1098
1099        let mut config = TrackerConfig::from_distance_name("euclidean", 1.0);
1100        config.hit_counter_max = counter_max;
1101        config.initialization_delay = delay;
1102
1103        let mut tracker = Tracker::new(config).unwrap();
1104
1105        let det1 = Detection::from_slice(&[1.0, 1.0], 1, 2).unwrap();
1106
1107        // During delay phase
1108        for _ in 0..delay {
1109            let active = tracker.update(vec![det1.clone()], 1, None);
1110            assert_eq!(active.len(), 0);
1111            assert_eq!(
1112                tracker.total_object_count(),
1113                0,
1114                "Total count should be 0 during init"
1115            );
1116            assert_eq!(
1117                tracker.current_object_count(),
1118                0,
1119                "Current count should be 0 during init"
1120            );
1121        }
1122
1123        // After delay, object becomes active
1124        let active = tracker.update(vec![det1.clone()], 1, None);
1125        assert_eq!(active.len(), 1);
1126        assert_eq!(tracker.total_object_count(), 1);
1127        assert_eq!(tracker.current_object_count(), 1);
1128
1129        // Object decays without detections but stays active for a while
1130        for _ in 0..counter_max - 1 {
1131            let active = tracker.update(vec![], 1, None);
1132            if !active.is_empty() {
1133                assert_eq!(tracker.total_object_count(), 1);
1134                assert_eq!(tracker.current_object_count(), 1);
1135            }
1136        }
1137
1138        // Object dies
1139        let active = tracker.update(vec![], 1, None);
1140        assert_eq!(active.len(), 0);
1141        assert_eq!(
1142            tracker.total_object_count(),
1143            1,
1144            "Total should stay 1 after object dies"
1145        );
1146        assert_eq!(
1147            tracker.current_object_count(),
1148            0,
1149            "Current should be 0 after object dies"
1150        );
1151
1152        // Add two new objects (far apart so they don't match each other)
1153        let det2 = Detection::from_slice(&[100.0, 100.0], 1, 2).unwrap();
1154        let det3 = Detection::from_slice(&[200.0, 200.0], 1, 2).unwrap();
1155
1156        // During delay phase for new objects
1157        for _ in 0..delay {
1158            let active = tracker.update(vec![det2.clone(), det3.clone()], 1, None);
1159            assert_eq!(active.len(), 0);
1160            assert_eq!(
1161                tracker.total_object_count(),
1162                1,
1163                "Total should still be 1 during init"
1164            );
1165            assert_eq!(tracker.current_object_count(), 0);
1166        }
1167
1168        // After delay, new objects become active
1169        let active = tracker.update(vec![det2, det3], 1, None);
1170        assert_eq!(active.len(), 2);
1171        assert_eq!(
1172            tracker.total_object_count(),
1173            3,
1174            "Total should be 3 (1 dead + 2 new)"
1175        );
1176        assert_eq!(tracker.current_object_count(), 2);
1177    }
1178
1179    /// Ported from Python: test_multiple_trackers
1180    /// Test that multiple trackers are independent
1181    #[test]
1182    fn test_multiple_trackers_independent() {
1183        let mut config1 = TrackerConfig::from_distance_name("euclidean", 1.0);
1184        config1.hit_counter_max = 2;
1185        config1.initialization_delay = 0;
1186
1187        let mut config2 = TrackerConfig::from_distance_name("euclidean", 1.0);
1188        config2.hit_counter_max = 2;
1189        config2.initialization_delay = 0;
1190
1191        let mut tracker1 = Tracker::new(config1).unwrap();
1192        let mut tracker2 = Tracker::new(config2).unwrap();
1193
1194        let det1 = Detection::from_slice(&[1.0, 1.0], 1, 2).unwrap();
1195        let det2 = Detection::from_slice(&[2.0, 2.0], 1, 2).unwrap();
1196
1197        let active1 = tracker1.update(vec![det1], 1, None);
1198        assert_eq!(active1.len(), 1);
1199
1200        let active2 = tracker2.update(vec![det2], 1, None);
1201        assert_eq!(active2.len(), 1);
1202
1203        // Trackers should have independent counts
1204        assert_eq!(tracker1.total_object_count(), 1);
1205        assert_eq!(tracker2.total_object_count(), 1);
1206
1207        // Objects should have different IDs (from different global ID pools)
1208        // Note: This depends on implementation - Rust uses factory pattern
1209    }
1210}