use image::DynamicImage;
use opencv::{
core::{Mat, Point, Rect, Scalar, Size},
imgproc,
prelude::*,
videoio::{self, VideoCapture, VideoWriter},
};
use ort::session::{Session, builder::GraphOptimizationLevel};
use std::error::Error;
use std::time::Instant;
use trackforge::trackers::deepsort::DeepSort;
use trackforge::traits::AppearanceExtractor;
use trackforge::types::BoundingBox;
use usls::{Config, DType, models::RTDETR};
/// Appearance-embedding extractor backed by an ONNX Runtime session.
struct ReIDExtractor {
    /// Inference session holding the loaded model.
    session: Session,
}

impl ReIDExtractor {
    /// Loads the ONNX model at `model_path` into a new inference session.
    ///
    /// The session is configured with `GraphOptimizationLevel::Level3` and
    /// four intra-op threads.
    ///
    /// # Errors
    /// Propagates any failure reported while building the session or loading
    /// the model file.
    fn new(model_path: &str) -> Result<Self, Box<dyn Error>> {
        Ok(Self {
            session: Session::builder()?
                .with_optimization_level(GraphOptimizationLevel::Level3)?
                .with_intra_threads(4)?
                .commit_from_file(model_path)?,
        })
    }

    /// Turns `image` into a normalised `1 x 3 x 224 x 224` float tensor.
    ///
    /// The image is stretched (aspect ratio is not preserved) to 224x224 with
    /// a triangle filter, converted to 8-bit RGB, scaled to `[0, 1]`, and then
    /// standardised per channel as `(value - mean) / std`.
    fn preprocess(&self, image: &DynamicImage) -> ndarray::Array4<f32> {
        /// Width and height of the square model input, in pixels.
        const SIDE: u32 = 224;
        // Per-channel (R, G, B) statistics; these look like the usual ImageNet
        // values — confirm they match what the ReID model was trained with.
        const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
        const STD: [f32; 3] = [0.229, 0.224, 0.225];

        let scaled = image
            .resize_exact(SIDE, SIDE, image::imageops::FilterType::Triangle)
            .to_rgb8();

        let side = SIDE as usize;
        let mut tensor = ndarray::Array4::<f32>::zeros((1, 3, side, side));
        for (col, row, px) in scaled.enumerate_pixels() {
            let (col, row) = (col as usize, row as usize);
            // Tensor layout is [batch, channel, row, column].
            for (channel, &value) in px.0.iter().enumerate() {
                tensor[[0, channel, row, col]] =
                    (f32::from(value) / 255.0 - MEAN[channel]) / STD[channel];
            }
        }
        tensor
    }
}
impl AppearanceExtractor for ReIDExtractor {
    /// Computes one appearance embedding per bounding box.
    ///
    /// For every entry of `bboxes` the corresponding region is cropped from
    /// `image`, preprocessed, and run through the session on its own (one
    /// inference call per box). The first model output is copied out and
    /// L2-normalised; vectors whose norm is at most `1e-6` are returned
    /// unscaled to avoid dividing by (almost) zero.
    ///
    /// The returned vector has the same length and order as `bboxes`.
    ///
    /// # Errors
    /// Returns an error if the model declares no inputs, or if tensor
    /// creation, inference, or output extraction fails.
    fn extract(
        &mut self,
        image: &DynamicImage,
        bboxes: &[BoundingBox],
    ) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
        if bboxes.is_empty() {
            return Ok(Vec::new());
        }

        // The input name does not change between boxes, so look it up once
        // and fail with an error (instead of panicking) if it is missing.
        let input_name = self
            .session
            .inputs
            .first()
            .ok_or("ReID model declares no inputs")?
            .name
            .clone();

        let mut embeddings = Vec::with_capacity(bboxes.len());
        for bbox in bboxes {
            // NOTE(review): `as u32` truncates; if the box fields are floats,
            // negative values saturate to 0 — confirm this is the intended
            // handling for boxes that extend past the top/left image edge.
            let crop = image.crop_imm(
                bbox.x as u32,
                bbox.y as u32,
                bbox.width as u32,
                bbox.height as u32,
            );

            let input = self.preprocess(&crop);
            let shape = input.shape().to_vec();
            let data = input.into_raw_vec_and_offset().0;
            let input_tensor = ort::value::Tensor::from_array((shape, data))?;

            let outputs = self
                .session
                .run(ort::inputs![input_name.as_str() => input_tensor])?;
            let values = outputs[0].try_extract_tensor::<f32>()?.1;
            let mut embedding: Vec<f32> = values.iter().copied().collect();

            // L2-normalise in place; leave near-zero vectors untouched.
            let norm = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
            if norm > 1e-6 {
                for value in &mut embedding {
                    *value /= norm;
                }
            }
            embeddings.push(embedding);
        }
        Ok(embeddings)
    }
}
/// Object detector wrapping a `usls` RT-DETR model.
struct Detector {
    /// The loaded detection model.
    model: RTDETR,
}

impl Detector {
    /// Builds a detector from the model file at `model_path`.
    ///
    /// The configuration starts from `Config::rtdetr()` and sets every dtype
    /// to the value parsed from `"f32"`.
    ///
    /// # Errors
    /// Propagates failures from parsing the dtype, committing the
    /// configuration, or constructing the model.
    fn new(model_path: &str) -> Result<Self, Box<dyn Error>> {
        let dtype: DType = "f32".parse()?;
        let config = Config::rtdetr()
            .with_model_file(model_path)
            .with_dtype_all(dtype)
            .commit()?;
        Ok(Self {
            model: RTDETR::new(config)?,
        })
    }

    /// Runs detection on a single image.
    ///
    /// Returns one `(bounding box, confidence, class id)` tuple per horizontal
    /// box in the first model result. Boxes are built from the top-left corner
    /// plus width and height. A missing confidence becomes `0.0` and a missing
    /// id becomes `0`. If the model yields no result, the list is empty.
    ///
    /// # Errors
    /// Propagates any failure from the model's forward pass.
    fn run(&mut self, img: &DynamicImage) -> Result<Vec<(BoundingBox, f32, i64)>, Box<dyn Error>> {
        // The model consumes a batch; wrap the (cloned) image in a batch of one.
        let batch: Vec<usls::Image> = vec![img.clone().into()];
        let outputs = self.model.forward(&batch)?;

        let Some(first) = outputs.first() else {
            return Ok(Vec::new());
        };

        let mut detections = Vec::new();
        for hbb in &first.hbbs {
            let (left, top) = (hbb.xmin(), hbb.ymin());
            let width = hbb.xmax() - left;
            let height = hbb.ymax() - top;
            let score = hbb.confidence().unwrap_or(0.0);
            let class_id = hbb.id().unwrap_or(0) as i64;
            detections.push((BoundingBox::new(left, top, width, height), score, class_id));
        }
        Ok(detections)
    }
}
fn main() -> Result<(), Box<dyn Error>> {
let args: Vec<String> = std::env::args().collect();
if args.len() < 5 {
println!("Usage: deepsort_ort <video_path> <model_path> <reid_path> <output_path>");
return Ok(());
}
let video_path = &args[1];
let model_path = &args[2];
let reid_path = &args[3];
let output_path = &args[4];
println!("Loading models...");
let mut detector = Detector::new(model_path)?;
let extractor = ReIDExtractor::new(reid_path)?;
let mut tracker = DeepSort::new(extractor, 60, 3, 0.7, 0.2, 100);
let mut cam = VideoCapture::from_file(video_path, videoio::CAP_ANY)?;
if !cam.is_opened()? {
return Err("Unable to open video file".into());
}
let width = cam.get(videoio::CAP_PROP_FRAME_WIDTH)? as i32;
let height = cam.get(videoio::CAP_PROP_FRAME_HEIGHT)? as i32;
let fps = cam.get(videoio::CAP_PROP_FPS)?;
let frame_count_total = cam.get(videoio::CAP_PROP_FRAME_COUNT)? as i32;
println!(
"Video info: {}x{} @ {}fps, {} frames",
width, height, fps, frame_count_total
);
let mut writer = VideoWriter::new(
output_path,
VideoWriter::fourcc('a', 'v', 'c', '1')?,
fps,
Size::new(width, height),
true,
)?;
println!("Processing video...");
let mut frame = Mat::default();
let mut frame_count = 0;
loop {
if frame_count > 300 {
println!("Reached frame limit (300), stopping.");
break;
}
if !cam.read(&mut frame)? {
break;
}
if frame.empty() {
continue;
}
let mut rgb = Mat::default();
imgproc::cvt_color(
&frame,
&mut rgb,
imgproc::COLOR_BGR2RGB,
0,
opencv::core::AlgorithmHint::ALGO_HINT_DEFAULT,
)?;
let data = rgb.data_bytes()?;
let img = image::RgbImage::from_raw(width as u32, height as u32, data.to_vec())
.ok_or("Failed to create image")?;
let dyn_img = DynamicImage::ImageRgb8(img);
let start = Instant::now();
let detections = detector.run(&dyn_img)?;
let tracks = tracker.update(&dyn_img, detections)?;
let duration = start.elapsed();
for track in &tracks {
let tlwh = track.to_tlwh();
let rect = Rect::new(
tlwh[0] as i32,
tlwh[1] as i32,
tlwh[2] as i32,
tlwh[3] as i32,
);
let color = Scalar::new(0.0, 255.0, 0.0, 0.0);
imgproc::rectangle(&mut frame, rect, color, 2, imgproc::LINE_8, 0)?;
let label = format!("ID: {}", track.track_id);
imgproc::put_text(
&mut frame,
&label,
Point::new(rect.x, rect.y - 5),
imgproc::FONT_HERSHEY_SIMPLEX,
0.5,
color,
1,
imgproc::LINE_8,
false,
)?;
}
imgproc::put_text(
&mut frame,
&format!("FPS: {:.1}", 1.0 / duration.as_secs_f32()),
Point::new(10, 30),
imgproc::FONT_HERSHEY_SIMPLEX,
1.0,
Scalar::new(0.0, 0.0, 255.0, 0.0),
2,
imgproc::LINE_8,
false,
)?;
println!("Frame {}: {} tracks", frame_count, tracks.len());
writer.write(&frame)?;
frame_count += 1;
}
println!("Done! Output saved to {}", output_path);
Ok(())
}