import argparse
import time
from pathlib import Path
import cv2
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
from ultralytics import YOLO
import trackforge
def get_embedder():
    """Build the appearance-embedding network used for Deep SORT re-ID.

    Returns:
        A ``(model, transform)`` pair: a torchvision ResNet18 whose final
        fully-connected layer is replaced by an identity (so a forward pass
        yields the pooled feature vector instead of class logits), and the
        preprocessing pipeline that turns a PIL crop into a normalized tensor.
    """
    print("🧠 Loading standard ResNet18 for embeddings...")
    backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Drop the classifier head: expose the 512-d pooled features directly.
    backbone.fc = torch.nn.Identity()
    backbone.eval()
    preprocess = T.Compose(
        [
            # Tall crop (H=128, W=64): the usual person re-ID aspect ratio.
            T.Resize((128, 64)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return backbone, preprocess
def extract_features(model, transform, frame, bboxes):
    """Compute L2-normalized appearance embeddings for each detection box.

    Args:
        model: embedding network (e.g. ResNet18 with an identity ``fc``).
        transform: torchvision preprocessing applied to each PIL crop.
        frame: BGR image array of shape (H, W, 3) as produced by OpenCV.
        bboxes: list of ``[x, y, w, h]`` boxes in pixel coordinates.

    Returns:
        One feature vector (plain Python list) per box, unit-length under
        the L2 norm; an empty list when ``bboxes`` is empty.
    """
    if not bboxes:
        return []
    frame_h, frame_w, _ = frame.shape
    patches = []
    for x, y, bw, bh in bboxes:
        left = max(0, int(x))
        top = max(0, int(y))
        right = min(frame_w, int(x + bw))
        bottom = min(frame_h, int(y + bh))
        if right <= left or bottom <= top:
            # Box collapsed after clamping to the frame: use a black
            # placeholder crop so the batch keeps one entry per bbox.
            rgb = np.zeros((128, 64, 3), dtype=np.uint8)
        else:
            rgb = cv2.cvtColor(frame[top:bottom, left:right], cv2.COLOR_BGR2RGB)
        patches.append(Image.fromarray(rgb))
    batch = torch.stack([transform(patch) for patch in patches])
    with torch.no_grad():
        embeddings = model(batch)
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.numpy().tolist()
def run_tracking(args):
    """Run YOLO person detection + Deep SORT tracking over a video file.

    Reads ``args.video`` frame by frame, detects people (COCO class 0) with
    the YOLO model at ``args.model``, embeds each detection with ResNet18
    appearance features, updates the Deep SORT tracker, draws the resulting
    track boxes and IDs, and writes the annotated video to ``args.output``.

    Fixes over the previous version: the capture and writer are released in
    a ``finally`` block (so the output file is finalized even if tracking
    raises), and an FPS of 0 reported by the container falls back to 30 so
    the ``VideoWriter`` produces a playable file.
    """
    video_path = args.video
    output_path = args.output
    model_path = args.model
    print(f"🚀 Loading YOLO model: {model_path}")
    yolo = YOLO(model_path)
    embedder, transform = get_embedder()
    print("📦 Initializing Deep SORT tracker...")
    tracker = trackforge.DeepSort(
        max_age=30,            # frames a lost track survives before deletion
        n_init=3,              # consecutive hits needed to confirm a track
        max_iou_distance=0.7,
        max_cosine_distance=0.2,
        nn_budget=100,         # appearance samples kept per track
    )
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"❌ Error opening video file: {video_path}")
        return
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Some containers report 0 FPS; fall back to a sane default so the
    # output file is playable.
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
    out = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )
    frame_count = 0
    t0 = time.time()
    # Fixed random palette so a given track ID keeps its color across frames.
    colors = np.random.randint(0, 255, (1000, 3)).tolist()
    print("🎬 Processing video...")
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            # classes=[0] restricts detection to the COCO "person" class.
            results = yolo.predict(frame, verbose=False, classes=[0])
            detections = []
            for r in results:
                for box in r.boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    # Deep SORT expects ([x, y, w, h], confidence, class).
                    detections.append((
                        [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
                        float(box.conf[0]),
                        int(box.cls[0]),
                    ))
            bboxes = [d[0] for d in detections]
            embeddings = extract_features(embedder, transform, frame, bboxes)
            tracks = tracker.update(detections, embeddings)
            for track in tracks:
                tid = track.track_id
                x1, y1, w, h = track.tlwh
                x2, y2 = x1 + w, y1 + h
                color = colors[tid % len(colors)]
                cv2.rectangle(
                    frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2
                )
                label = f"ID:{tid}"
                cv2.putText(
                    frame, label, (int(x1), int(y1) - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2
                )
            out.write(frame)
            if frame_count % 20 == 0:
                print(f"Frame {frame_count} - Tracks: {len(tracks)}")
    finally:
        # Always release the capture and writer, even if tracking raises,
        # so the output container is finalized and the file handles freed.
        cap.release()
        out.release()
    elapsed = time.time() - t0
    print(f"✅ Done! Processed {frame_count} frames in {elapsed:.2f}s")
    print(f"📹 Saved to {output_path}")
def main():
    """Parse command-line arguments and launch tracking on the input video."""
    parser = argparse.ArgumentParser(
        description="Deep SORT tracking with YOLO detection and ResNet18 embeddings."
    )
    parser.add_argument(
        "--video", type=str, default="people.mp4", help="Path to input video"
    )
    parser.add_argument(
        "--output", type=str, default="output_deepsort.mp4", help="Path to output video"
    )
    parser.add_argument(
        "--model", type=str, default="yolo11n.pt", help="YOLO model path"
    )
    args = parser.parse_args()
    # Bail out early with a warning rather than letting OpenCV fail later.
    if Path(args.video).exists():
        run_tracking(args)
    else:
        print(f"⚠️ Video {args.video} not found.")
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()