mod nn_matching;
mod track;
mod tracker;
#[cfg(feature = "python")]
pub mod python;
pub use nn_matching::Metric;
pub use track::{Track, TrackState};
pub use tracker::DeepSortTracker;
use crate::traits::AppearanceExtractor;
use crate::types::BoundingBox;
use image::DynamicImage;
use nn_matching::NearestNeighborDistanceMetric;
use std::error::Error;
/// DeepSORT pipeline that pairs an appearance extractor with a [`DeepSortTracker`].
///
/// The extractor turns detection boxes into embeddings; the tracker consumes the
/// detections together with those embeddings (see `DeepSort::update`).
pub struct DeepSort<E>
where
    E: AppearanceExtractor,
{
    /// Produces one appearance embedding per bounding box.
    extractor: E,
    /// Holds the track list and performs the predict/update steps.
    tracker: DeepSortTracker,
}
impl<E: AppearanceExtractor> DeepSort<E> {
    /// Builds a DeepSORT pipeline around `extractor`.
    ///
    /// A cosine `NearestNeighborDistanceMetric` is created from `max_cosine_distance`
    /// and `Some(nn_budget)`; `max_age`, `n_init` and `max_iou_distance` are forwarded
    /// unchanged to [`DeepSortTracker::new`] together with that metric.
    pub fn new(
        extractor: E,
        max_age: usize,
        n_init: usize,
        max_iou_distance: f32,
        max_cosine_distance: f32,
        nn_budget: usize,
    ) -> Self {
        let metric = NearestNeighborDistanceMetric::new(
            Metric::Cosine,
            max_cosine_distance,
            Some(nn_budget),
        );
        let tracker = DeepSortTracker::new(metric, max_age, n_init, max_iou_distance);
        Self { extractor, tracker }
    }

    /// Processes one frame and returns the tracks that are confirmed and were matched
    /// in this very call (`time_since_update == 0`).
    ///
    /// `detections` holds one `(bounding box, f32, i64)` tuple per detection. Only the
    /// box is used here; the whole tuple is handed to the tracker as-is (the `f32` and
    /// `i64` are presumably confidence and class id — confirm against
    /// `DeepSortTracker::update`). When `detections` is empty the extractor is not
    /// called and the tracker is updated with no embeddings.
    ///
    /// # Errors
    ///
    /// Returns an error if the extractor fails, or if it returns a number of
    /// embeddings different from the number of detections. In both cases the tracker
    /// has not been touched, so the call can be retried or the frame skipped without
    /// the tracker having been advanced for a frame it never saw.
    pub fn update(
        &mut self,
        image: &DynamicImage,
        detections: Vec<(BoundingBox, f32, i64)>,
    ) -> Result<Vec<Track>, Box<dyn Error>> {
        let bboxes: Vec<BoundingBox> = detections.iter().map(|(bbox, _, _)| *bbox).collect();

        // Run the only fallible step first: once `predict` has been called the tracker
        // state is already advanced, so an early `?` return after it would leave the
        // tracker out of step with the frames that were actually processed.
        let embeddings = if bboxes.is_empty() {
            Vec::new()
        } else {
            self.extractor.extract(image, &bboxes)?
        };

        // The tracker receives detections and embeddings as two parallel slices;
        // refuse to hand it slices of different lengths.
        if embeddings.len() != detections.len() {
            return Err(format!(
                "appearance extractor returned {} embeddings for {} detections",
                embeddings.len(),
                detections.len()
            )
            .into());
        }

        self.tracker.predict();
        self.tracker.update(&detections, &embeddings);

        Ok(self
            .tracker
            .tracks
            .iter()
            .filter(|t| t.is_confirmed() && t.time_since_update == 0)
            .cloned()
            .collect())
    }
}
impl<E: AppearanceExtractor> DeepSort<E> {
    /// Builds a pipeline around `extractor` using this module's preset parameters.
    ///
    /// Equivalent to calling `DeepSort::new` with `max_age = 70`, `n_init = 3`,
    /// `max_iou_distance = 0.7`, `max_cosine_distance = 0.2` and `nn_budget = 100`.
    pub fn new_default(extractor: E) -> Self {
        // Named after the `DeepSort::new` parameters they are passed as.
        const MAX_AGE: usize = 70;
        const N_INIT: usize = 3;
        const MAX_IOU_DISTANCE: f32 = 0.7;
        const MAX_COSINE_DISTANCE: f32 = 0.2;
        const NN_BUDGET: usize = 100;

        Self::new(
            extractor,
            MAX_AGE,
            N_INIT,
            MAX_IOU_DISTANCE,
            MAX_COSINE_DISTANCE,
            NN_BUDGET,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BoundingBox;

    /// Extractor stub: yields the fixed embedding `[1.0, 0.0]` for every box it is given.
    struct MockExtractor;

    impl AppearanceExtractor for MockExtractor {
        fn extract(
            &mut self,
            _image: &DynamicImage,
            bboxes: &[BoundingBox],
        ) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
            let embeddings = bboxes.iter().map(|_| vec![1.0, 0.0]).collect();
            Ok(embeddings)
        }
    }

    /// A freshly built pipeline holds no tracks.
    #[test]
    fn test_deepsort_initialization() {
        let deepsort = DeepSort::new_default(MockExtractor);
        assert_eq!(deepsort.tracker.tracks.len(), 0);
    }

    /// Feeding the same detection three times: nothing is reported for the first two
    /// frames, and a single confirmed track is reported on the third.
    #[test]
    fn test_deepsort_track_lifecycle() {
        let mut deepsort = DeepSort::new_default(MockExtractor);
        let frame = DynamicImage::new_rgb8(100, 100);
        let detections = vec![(BoundingBox::new(10.0, 10.0, 20.0, 20.0), 0.9, 0)];

        // Frame 1: a track is created internally but is not confirmed, so it is not returned.
        let reported = deepsort.update(&frame, detections.clone()).unwrap();
        assert!(reported.is_empty());
        assert_eq!(deepsort.tracker.tracks.len(), 1);
        assert!(!deepsort.tracker.tracks[0].is_confirmed());

        // Frame 2: still nothing reported.
        let reported = deepsort.update(&frame, detections.clone()).unwrap();
        assert!(reported.is_empty());

        // Frame 3: the track is now confirmed and returned.
        let reported = deepsort.update(&frame, detections).unwrap();
        assert_eq!(reported.len(), 1);
        assert!(reported[0].is_confirmed());
    }
}