#![doc = include_str!("README.md")]
use crate::utils::kalman::{CovarianceMatrix, KalmanFilter, MeasurementVector, StateVector};
use std::collections::HashSet;
/// Lifecycle state of an [`OcSortTrack`].
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum OcSortTrackState {
/// Newly created track that has not yet been promoted; the tracker deletes a
/// tentative track as soon as it goes one frame without a matched detection.
Tentative,
/// Track whose `hit_streak` reached the tracker's `min_hits` on a matched frame.
Confirmed,
/// Track marked for removal; purged from the tracker at the end of `update`.
Deleted,
}
/// A single tracked object together with its Kalman-filter state and the
/// recent observations used for direction estimation and re-update.
#[derive(Debug, Clone)]
pub struct OcSortTrack {
/// Current box as `[x, y, width, height]`, where `(x, y)` is the box's minimum
/// corner. Holds the filter's predicted box after `predict` and the raw
/// detection box after a match.
pub tlwh: [f32; 4],
/// Score of the most recently matched detection.
pub score: f32,
/// Class id of the most recently matched detection.
pub class_id: i64,
/// Identifier assigned by the tracker when the track was created.
pub track_id: u64,
/// Lifecycle state (tentative / confirmed / deleted).
pub state: OcSortTrackState,
/// Total number of detections matched to this track, including the first.
pub hits: usize,
/// Number of consecutive matched frames; cleared by `predict` once a frame
/// has been missed.
pub hit_streak: usize,
/// Frames elapsed since the last matched detection (0 on a matched frame).
pub time_since_update: usize,
/// Number of frames the track has existed; starts at 1 and grows by one per `predict`.
pub age: usize,
// Kalman state; the first four components are read as (cx, cy, aspect, height).
mean: StateVector,
// Kalman state covariance paired with `mean`.
covariance: CovarianceMatrix,
// Matched observations as (xyah measurement, frame id), oldest first; the
// tracker bounds the length to `delta_t + 1` via `push_observation`.
observations: Vec<(MeasurementVector, usize)>,
}
impl OcSortTrack {
    /// Starts a tentative track from its first detection.
    ///
    /// The Kalman filter is initiated from the detection converted to
    /// `(cx, cy, aspect, height)`, and that measurement becomes the first stored
    /// observation, tagged with `frame_id`.
    fn new(
        tlwh: [f32; 4],
        score: f32,
        class_id: i64,
        track_id: u64,
        frame_id: usize,
        kf: &KalmanFilter,
    ) -> Self {
        let first_obs = tlwh_to_xyah(&tlwh);
        let (mean, covariance) = kf.initiate(&first_obs);
        Self {
            tlwh,
            score,
            class_id,
            track_id,
            state: OcSortTrackState::Tentative,
            hits: 1,
            hit_streak: 1,
            time_since_update: 0,
            age: 1,
            mean,
            covariance,
            observations: vec![(first_obs, frame_id)],
        }
    }

    /// Advances the track by one frame.
    ///
    /// Clears `hit_streak` if the track was already unmatched going into this
    /// frame, runs the Kalman prediction, bumps `age` and `time_since_update`,
    /// and refreshes `tlwh` from the predicted state.
    fn predict(&mut self, kf: &KalmanFilter) {
        let missed_previous_frame = self.time_since_update > 0;
        if missed_previous_frame {
            self.hit_streak = 0;
        }
        (self.mean, self.covariance) = kf.predict(&self.mean, &self.covariance);
        self.age += 1;
        self.time_since_update += 1;
        self.tlwh = xyah_to_tlwh(&self.mean);
    }

    /// Applies a Kalman measurement update with `xyah` and refreshes `tlwh`
    /// from the corrected state.
    fn update_kf(&mut self, xyah: &MeasurementVector, kf: &KalmanFilter) {
        (self.mean, self.covariance) = kf.update(&self.mean, &self.covariance, xyah);
        self.tlwh = xyah_to_tlwh(&self.mean);
    }

    /// Unit direction of motion derived from stored observations.
    ///
    /// Compares the newest observation with the one `delta_t` entries earlier
    /// (or the oldest one if fewer are stored). Returns `[dy, dx]` normalised by
    /// the displacement length plus a small epsilon, or `None` when fewer than
    /// two observations exist.
    fn obs_direction(&self, delta_t: usize) -> Option<[f32; 2]> {
        let count = self.observations.len();
        if count < 2 {
            return None;
        }
        let newest = &self.observations[count - 1].0;
        let anchor = &self.observations[count.saturating_sub(delta_t + 1)].0;
        let delta_y = newest[1] - anchor[1];
        let delta_x = newest[0] - anchor[0];
        // The epsilon keeps a stationary track from dividing by zero.
        let length = (delta_y * delta_y + delta_x * delta_x).sqrt() + 1e-6;
        Some([delta_y / length, delta_x / length])
    }

    /// Observation-centric re-update after a gap of missed frames.
    ///
    /// When more than one frame separates `current_frame` from the last stored
    /// observation, the filter is re-initiated from that observation and then
    /// stepped (predict + update) along a straight-line interpolation, in
    /// box space, from the last observed box to `current_xyah`. The resulting
    /// state replaces `mean`/`covariance`. Does nothing when there is no stored
    /// observation or the gap is at most one frame.
    fn our_re_update(
        &mut self,
        current_xyah: &MeasurementVector,
        current_frame: usize,
        kf: &KalmanFilter,
    ) {
        let (last_obs, last_frame) = match self.observations.last() {
            Some(entry) => entry,
            None => return,
        };
        let elapsed = current_frame as isize - *last_frame as isize;
        if elapsed <= 1 {
            return;
        }
        let gap = elapsed as usize;

        let start_box = xyah4_to_tlwh(last_obs);
        let end_box = xyah4_to_tlwh(current_xyah);
        let (mut mean, mut covariance) = kf.initiate(last_obs);
        for step in 1..=gap {
            // Fraction of the way from the last observation to the current one.
            let t = step as f32 / gap as f32;
            let mut virtual_tlwh = [0.0_f32; 4];
            for (slot, (&from, &to)) in virtual_tlwh
                .iter_mut()
                .zip(start_box.iter().zip(end_box.iter()))
            {
                *slot = from + (to - from) * t;
            }
            let virtual_xyah = tlwh_to_xyah(&virtual_tlwh);
            let (predicted_mean, predicted_cov) = kf.predict(&mean, &covariance);
            (mean, covariance) = kf.update(&predicted_mean, &predicted_cov, &virtual_xyah);
        }

        self.mean = mean;
        self.covariance = covariance;
        self.tlwh = xyah_to_tlwh(&self.mean);
    }

    /// Appends an observation and, if the history now exceeds `max_obs`
    /// entries, drops the oldest one.
    fn push_observation(&mut self, xyah: MeasurementVector, frame_id: usize, max_obs: usize) {
        self.observations.push((xyah, frame_id));
        let over_capacity = self.observations.len() > max_obs;
        if over_capacity {
            self.observations.remove(0);
        }
    }

    /// Flags the track for removal.
    fn mark_deleted(&mut self) {
        self.state = OcSortTrackState::Deleted;
    }

    /// Whether the track is in the `Confirmed` state.
    fn is_confirmed(&self) -> bool {
        matches!(self.state, OcSortTrackState::Confirmed)
    }
}
/// Converts an `[x, y, width, height]` box into a
/// `(center x, center y, aspect ratio, height)` measurement.
///
/// The aspect ratio is `width / height`, with the divisor floored at `1e-6`
/// so a zero-height box does not divide by zero.
fn tlwh_to_xyah(tlwh: &[f32; 4]) -> MeasurementVector {
    let [left, top, width, height] = *tlwh;
    let center_x = left + width / 2.0;
    let center_y = top + height / 2.0;
    let aspect = width / height.max(1e-6);
    MeasurementVector::from_vec(vec![center_x, center_y, aspect, height])
}
/// Extracts an `[x, y, width, height]` box from a Kalman state whose first
/// four components are `(center x, center y, aspect ratio, height)`.
fn xyah_to_tlwh(state: &StateVector) -> [f32; 4] {
    let (center_x, center_y, aspect, height) = (state[0], state[1], state[2], state[3]);
    let width = aspect * height;
    [center_x - width / 2.0, center_y - height / 2.0, width, height]
}
/// Converts a `(center x, center y, aspect ratio, height)` measurement back
/// into an `[x, y, width, height]` box.
fn xyah4_to_tlwh(xyah: &MeasurementVector) -> [f32; 4] {
    let height = xyah[3];
    let width = xyah[2] * height;
    let left = xyah[0] - width / 2.0;
    let top = xyah[1] - height / 2.0;
    [left, top, width, height]
}
/// Internal named form of one input detection tuple passed to `OcSort::update`.
#[derive(Debug, Clone)]
struct Detection {
/// Box as `[x, y, width, height]`, `(x, y)` being the box's minimum corner.
tlwh: [f32; 4],
/// Detector score; also scales the direction bonus during association.
score: f32,
/// Class id copied onto the matched or newly created track.
class_id: i64,
}
/// OC-SORT style multi-object tracker: Kalman prediction, IoU association
/// with a direction-consistency bonus, a second association round against each
/// track's last observation, and an observation-centric re-update after gaps.
pub struct OcSort {
// Live tracks (tentative and confirmed); deleted ones are purged each frame.
tracks: Vec<OcSortTrack>,
// A track is deleted once `time_since_update` exceeds this value.
max_age: usize,
// `hit_streak` needed on a matched frame to become confirmed.
min_hits: usize,
// Association threshold; a pair is accepted when its cost is at most `1 - iou_threshold`.
iou_threshold: f32,
// Look-back (in stored observations) for direction estimation; `delta_t + 1` observations are kept.
delta_t: usize,
// Weight of the direction bonus, clamped to [0, 1] by `new`.
inertia: f32,
// Kalman filter shared by all tracks.
kf: KalmanFilter,
// Id handed to the next created track; starts at 1.
next_id: u64,
// Number of `update` calls so far; used as the frame id of observations.
frame_count: usize,
}
impl OcSort {
    /// Creates an empty tracker.
    ///
    /// * `max_age` – a track is deleted once it has gone more than this many
    ///   frames without a matched detection.
    /// * `min_hits` – consecutive matched frames required before a track is
    ///   promoted to `Confirmed`.
    /// * `iou_threshold` – association threshold; a track/detection pair is
    ///   accepted when its cost is at most `1 - iou_threshold`.
    /// * `delta_t` – how many stored observations to look back when estimating
    ///   a track's direction; `delta_t + 1` observations are retained per track.
    /// * `inertia` – weight of the direction-consistency bonus, clamped to
    ///   `[0, 1]`.
    pub fn new(
        max_age: usize,
        min_hits: usize,
        iou_threshold: f32,
        delta_t: usize,
        inertia: f32,
    ) -> Self {
        Self {
            tracks: Vec::new(),
            max_age,
            min_hits,
            iou_threshold,
            delta_t,
            inertia: inertia.clamp(0.0, 1.0),
            kf: KalmanFilter::default(),
            next_id: 1,
            frame_count: 0,
        }
    }

    /// Processes one frame of detections, each given as
    /// `([x, y, width, height], score, class_id)`.
    ///
    /// Steps: predict every track, associate detections with tracks, update
    /// matched tracks (with a re-update first when the track had missed
    /// frames), create a tentative track per unmatched detection, then apply
    /// lifecycle rules and purge deleted tracks.
    ///
    /// Returns clones of the tracks that are `Confirmed` and were matched in
    /// this frame.
    pub fn update(&mut self, detections: Vec<([f32; 4], f32, i64)>) -> Vec<OcSortTrack> {
        self.frame_count += 1;
        let detections: Vec<Detection> = detections
            .into_iter()
            .map(|(tlwh, score, class_id)| Detection {
                tlwh,
                score,
                class_id,
            })
            .collect();

        for track in &mut self.tracks {
            track.predict(&self.kf);
        }

        let (matches, unmatched_dets, unmatched_trks) = self.associate(&detections);

        for &(det_idx, trk_idx) in &matches {
            let det = &detections[det_idx];
            let xyah = tlwh_to_xyah(&det.tlwh);
            let track = &mut self.tracks[trk_idx];
            // `predict` above has already advanced `time_since_update`, so a
            // value of 1 means the track was matched in the previous frame;
            // only larger values indicate a real gap worth re-updating over.
            if track.time_since_update > 1 {
                track.our_re_update(&xyah, self.frame_count, &self.kf);
            }
            track.update_kf(&xyah, &self.kf);
            track.push_observation(xyah, self.frame_count, self.delta_t + 1);
            // Report the raw detection rather than the filtered box.
            track.tlwh = det.tlwh;
            track.score = det.score;
            track.class_id = det.class_id;
            track.hits += 1;
            track.hit_streak += 1;
            track.time_since_update = 0;
        }

        // New tracks are appended, so indices in `unmatched_trks` stay valid.
        for det_idx in unmatched_dets {
            let det = &detections[det_idx];
            let track = OcSortTrack::new(
                det.tlwh,
                det.score,
                det.class_id,
                self.next_id,
                self.frame_count,
                &self.kf,
            );
            self.next_id += 1;
            self.tracks.push(track);
        }

        let missed: HashSet<usize> = unmatched_trks.into_iter().collect();
        for (idx, track) in self.tracks.iter_mut().enumerate() {
            if track.time_since_update == 0 && track.hit_streak >= self.min_hits {
                track.state = OcSortTrackState::Confirmed;
            }
            if track.time_since_update > self.max_age {
                track.mark_deleted();
            }
            // Tentative tracks get no grace period: one missed frame removes them.
            if track.state == OcSortTrackState::Tentative && missed.contains(&idx) {
                track.mark_deleted();
            }
        }
        self.tracks.retain(|t| t.state != OcSortTrackState::Deleted);

        self.tracks
            .iter()
            .filter(|t| t.is_confirmed() && t.time_since_update == 0)
            .cloned()
            .collect()
    }

    /// Associates detections with the current (already predicted) tracks.
    ///
    /// Round one greedily matches on `1 - (IoU + direction bonus)` between
    /// predicted track boxes and detections. Round two retries the leftovers
    /// using each track's last observed box instead of its prediction.
    ///
    /// Returns `(matches, unmatched_detections, unmatched_tracks)`, where each
    /// match is `(detection index, track index)`.
    ///
    /// NOTE(review): the threshold is applied to IoU *plus* the direction
    /// bonus, so with a large `inertia` a pair whose IoU alone is below
    /// `iou_threshold` can still match; the reference OC-SORT implementation
    /// appears to reject on IoU alone — confirm this is intended.
    fn associate(&self, detections: &[Detection]) -> (Vec<(usize, usize)>, Vec<usize>, Vec<usize>) {
        let n_trks = self.tracks.len();
        let n_dets = detections.len();
        if n_trks == 0 {
            return (Vec::new(), (0..n_dets).collect(), Vec::new());
        }
        if n_dets == 0 {
            return (Vec::new(), Vec::new(), (0..n_trks).collect());
        }

        let pred_boxes: Vec<[f32; 4]> = self.tracks.iter().map(|t| t.tlwh).collect();
        let det_boxes: Vec<[f32; 4]> = detections.iter().map(|d| d.tlwh).collect();
        let ious = crate::utils::geometry::iou_batch(&pred_boxes, &det_boxes);
        let bonus = self.direction_bonus(detections);

        let cost_matrix: Vec<Vec<f32>> = (0..n_trks)
            .map(|i| {
                (0..n_dets)
                    .map(|j| 1.0 - (ious[i][j] + bonus[i][j]))
                    .collect()
            })
            .collect();
        let (mut matches, unmatched_dets, unmatched_trks) =
            greedy_match(&cost_matrix, 1.0 - self.iou_threshold);

        let (unmatched_dets, unmatched_trks) =
            self.rematch_on_last_observation(detections, &mut matches, unmatched_dets, unmatched_trks);
        (matches, unmatched_dets, unmatched_trks)
    }

    /// Builds the `tracks x detections` direction-consistency bonus.
    ///
    /// For a track with a known motion direction, each detection is scored by
    /// how well the direction from the track's last observation to the
    /// detection centre agrees with it, scaled by `inertia` and the detection
    /// score and floored at zero. Tracks without a direction keep a zero row.
    fn direction_bonus(&self, detections: &[Detection]) -> Vec<Vec<f32>> {
        let mut bonus = vec![vec![0.0_f32; detections.len()]; self.tracks.len()];
        for (track, row) in self.tracks.iter().zip(bonus.iter_mut()) {
            let vel_dir = match track.obs_direction(self.delta_t) {
                Some(v) => v,
                None => continue,
            };
            let (last_xyah, _) = track
                .observations
                .last()
                .expect("observations is non-empty by invariant");
            let last_cx = last_xyah[0];
            let last_cy = last_xyah[1];
            for (slot, det) in row.iter_mut().zip(detections) {
                let det_cx = det.tlwh[0] + det.tlwh[2] / 2.0;
                let det_cy = det.tlwh[1] + det.tlwh[3] / 2.0;
                let dy = det_cy - last_cy;
                let dx = det_cx - last_cx;
                let norm = (dy * dy + dx * dx).sqrt() + 1e-6;
                let cand_dy = dy / norm;
                let cand_dx = dx / norm;
                // Angle between the track's direction and the candidate direction.
                let dot = (vel_dir[0] * cand_dy + vel_dir[1] * cand_dx).clamp(-1.0, 1.0);
                let angle = dot.acos();
                // Positive when the angle is below 90 degrees, negative above.
                let normalized = (std::f32::consts::FRAC_PI_2 - angle.abs()) / std::f32::consts::PI;
                *slot = (normalized * self.inertia * det.score).max(0.0);
            }
        }
        bonus
    }

    /// Second association round over what round one left unmatched.
    ///
    /// Compares leftover detections with each leftover track's last observed
    /// box. If any pair exceeds `iou_threshold`, the leftovers are greedily
    /// matched on `1 - IoU`; new pairs are appended to `matches` (translated
    /// back to global indices). Returns the still-unmatched detection and track
    /// indices; the inputs are returned unchanged when the round does not run.
    fn rematch_on_last_observation(
        &self,
        detections: &[Detection],
        matches: &mut Vec<(usize, usize)>,
        unmatched_dets: Vec<usize>,
        unmatched_trks: Vec<usize>,
    ) -> (Vec<usize>, Vec<usize>) {
        if unmatched_dets.is_empty() || unmatched_trks.is_empty() {
            return (unmatched_dets, unmatched_trks);
        }

        let left_det_boxes: Vec<[f32; 4]> = unmatched_dets
            .iter()
            .map(|&di| detections[di].tlwh)
            .collect();
        let left_trk_obs: Vec<[f32; 4]> = unmatched_trks
            .iter()
            .map(|&ti| {
                xyah4_to_tlwh(
                    &self.tracks[ti]
                        .observations
                        .last()
                        .expect("observations is non-empty by invariant")
                        .0,
                )
            })
            .collect();
        let iou_left = crate::utils::geometry::iou_batch(&left_trk_obs, &left_det_boxes);
        let max_iou = iou_left
            .iter()
            .flat_map(|r| r.iter())
            .cloned()
            .fold(f32::NEG_INFINITY, f32::max);

        if max_iou > self.iou_threshold {
            let cost_left: Vec<Vec<f32>> = iou_left
                .iter()
                .map(|row| row.iter().map(|&v| 1.0 - v).collect())
                .collect();
            let (r2_matches, r2_ud, r2_ut) = greedy_match(&cost_left, 1.0 - self.iou_threshold);
            // Round-two indices are local to the leftover lists; map them back.
            matches.extend(
                r2_matches
                    .into_iter()
                    .map(|(det_local, trk_local)| (unmatched_dets[det_local], unmatched_trks[trk_local])),
            );
            let still_dets = r2_ud.into_iter().map(|di| unmatched_dets[di]).collect();
            let still_trks = r2_ut.into_iter().map(|ti| unmatched_trks[ti]).collect();
            return (still_dets, still_trks);
        }
        (unmatched_dets, unmatched_trks)
    }
}
impl Default for OcSort {
    /// Builds a tracker with `max_age = 30`, `min_hits = 3`,
    /// `iou_threshold = 0.3`, `delta_t = 3` and `inertia = 0.2`.
    fn default() -> Self {
        const MAX_AGE: usize = 30;
        const MIN_HITS: usize = 3;
        const IOU_THRESHOLD: f32 = 0.3;
        const DELTA_T: usize = 3;
        const INERTIA: f32 = 0.2;
        Self::new(MAX_AGE, MIN_HITS, IOU_THRESHOLD, DELTA_T, INERTIA)
    }
}
/// Greedily pairs rows (tracks) with columns (detections) in order of
/// ascending cost.
///
/// `cost_matrix[track][detection]` holds the pairing cost; only pairs whose
/// cost is at most `threshold` are eligible. NaN costs are never eligible.
/// The column count is taken from the first row; entries beyond it in longer
/// rows are ignored.
///
/// Returns `(matches, unmatched_columns, unmatched_rows)`, where each match is
/// `(column, row)`, i.e. `(detection, track)`. Both unmatched lists are in
/// ascending index order, so callers that assign ids from them behave
/// reproducibly. An empty matrix yields three empty vectors.
fn greedy_match(
    cost_matrix: &[Vec<f32>],
    threshold: f32,
) -> (Vec<(usize, usize)>, Vec<usize>, Vec<usize>) {
    if cost_matrix.is_empty() {
        return (Vec::new(), Vec::new(), Vec::new());
    }
    let rows = cost_matrix.len();
    let cols = cost_matrix[0].len();

    // `cost <= threshold` is false for NaN, so invalid costs are dropped here
    // instead of slipping past a `cost > threshold` early-exit check.
    let mut candidates: Vec<(f32, usize, usize)> = cost_matrix
        .iter()
        .enumerate()
        .flat_map(|(trk, row)| {
            row.iter()
                .take(cols)
                .enumerate()
                .map(move |(det, &cost)| (cost, trk, det))
        })
        .filter(|&(cost, _, _)| cost <= threshold)
        .collect();
    // Stable sort: equal costs keep row-major order, so ties resolve the same
    // way on every run.
    candidates.sort_by(|a, b| a.0.total_cmp(&b.0));

    let mut row_free = vec![true; rows];
    let mut col_free = vec![true; cols];
    let mut matches = Vec::new();
    for (_, trk, det) in candidates {
        if row_free[trk] && col_free[det] {
            row_free[trk] = false;
            col_free[det] = false;
            matches.push((det, trk));
        }
    }

    let unmatched_dets = (0..cols).filter(|&c| col_free[c]).collect();
    let unmatched_trks = (0..rows).filter(|&r| row_free[r]).collect();
    (matches, unmatched_dets, unmatched_trks)
}
#[cfg(feature = "python")]
use pyo3::prelude::*;
/// One tracked object as returned to Python:
/// `(track_id, [x, y, width, height], score, class_id)`.
#[cfg(feature = "python")]
type PyTrackingResult = (u64, [f32; 4], f32, i64);
// Python-facing wrapper around `OcSort`, exported to Python under the class
// name `OCSORT`. Only compiled when the `python` feature is enabled.
#[cfg(feature = "python")]
#[pyclass(name = "OCSORT")]
pub struct PyOcSort {
// The wrapped Rust tracker; all calls are forwarded to it.
inner: OcSort,
}
#[cfg(feature = "python")]
#[pymethods]
impl PyOcSort {
    // Python constructor. Every argument is optional and falls back to the
    // value given in the `signature` attribute below; the arguments are passed
    // straight through to `OcSort::new`.
    #[new]
    #[pyo3(signature = (max_age=30, min_hits=3, iou_threshold=0.3, delta_t=3, inertia=0.2))]
    fn new(
        max_age: usize,
        min_hits: usize,
        iou_threshold: f32,
        delta_t: usize,
        inertia: f32,
    ) -> Self {
        let inner = OcSort::new(max_age, min_hits, iou_threshold, delta_t, inertia);
        Self { inner }
    }

    // Runs one tracking step on `detections`, each given as
    // `([x, y, width, height], score, class_id)`, and flattens every returned
    // track into a `(track_id, box, score, class_id)` tuple.
    fn update(&mut self, detections: Vec<([f32; 4], f32, i64)>) -> PyResult<Vec<PyTrackingResult>> {
        let mut results: Vec<PyTrackingResult> = Vec::new();
        for track in self.inner.update(detections) {
            results.push((track.track_id, track.tlwh, track.score, track.class_id));
        }
        Ok(results)
    }
}
// Unit tests for the tracker and for the private helpers of `OcSortTrack`.
#[cfg(test)]
mod tests {
use super::*;
// Builds one detection tuple `([x, y, w, h], score, class_id)` with class id 0.
fn det(x: f32, y: f32, w: f32, h: f32, s: f32) -> ([f32; 4], f32, i64) {
([x, y, w, h], s, 0)
}
// A fresh tracker fed an empty frame reports no tracks.
#[test]
fn test_ocsort_empty_detections() {
let mut tracker = OcSort::new(30, 1, 0.3, 3, 0.2);
let tracks = tracker.update(vec![]);
assert!(tracks.is_empty());
}
// With min_hits = 3, the same box fed for four frames is reported on the
// fourth frame as a single track with id 1.
#[test]
fn test_ocsort_single_detection_confirmed_after_min_hits() {
let mut tracker = OcSort::new(30, 3, 0.3, 3, 0.2);
for _ in 0..3 {
tracker.update(vec![det(100.0, 100.0, 50.0, 100.0, 0.9)]);
}
let tracks = tracker.update(vec![det(100.0, 100.0, 50.0, 100.0, 0.9)]);
assert_eq!(tracks.len(), 1);
assert_eq!(tracks[0].track_id, 1);
}
// With min_hits = 1 a new detection is reported in the very frame it appears.
#[test]
fn test_ocsort_min_hits_one() {
let mut tracker = OcSort::new(30, 1, 0.3, 3, 0.2);
let tracks = tracker.update(vec![det(100.0, 100.0, 50.0, 100.0, 0.9)]);
assert_eq!(tracks.len(), 1);
}
// With max_age = 2, one detection followed by empty frames yields no output.
// NOTE(review): `update` only returns tracks matched in the current frame, so
// an empty result here does not by itself prove the track was deleted.
#[test]
fn test_ocsort_track_deleted_after_max_age() {
let mut tracker = OcSort::new(2, 1, 0.3, 3, 0.2);
tracker.update(vec![det(100.0, 100.0, 50.0, 100.0, 0.9)]);
for _ in 0..3 {
tracker.update(vec![]);
}
let tracks = tracker.update(vec![]);
assert!(tracks.is_empty());
}
// Two far-apart detections in the first frame produce two tracks with distinct ids.
#[test]
fn test_ocsort_two_objects_separate_ids() {
let mut tracker = OcSort::new(30, 1, 0.3, 3, 0.2);
let tracks = tracker.update(vec![
det(100.0, 100.0, 50.0, 100.0, 0.9),
det(400.0, 400.0, 50.0, 100.0, 0.85),
]);
assert_eq!(tracks.len(), 2);
let ids: std::collections::HashSet<u64> = tracks.iter().map(|t| t.track_id).collect();
assert_eq!(ids.len(), 2);
}
// After a second object appears in frame two, the smallest reported id is 1.
// Only the lowest id is asserted.
#[test]
fn test_ocsort_ids_sequential() {
let mut tracker = OcSort::new(30, 1, 0.3, 3, 0.2);
tracker.update(vec![det(100.0, 100.0, 50.0, 100.0, 0.9)]);
let tracks = tracker.update(vec![
det(100.0, 100.0, 50.0, 100.0, 0.9),
det(400.0, 400.0, 50.0, 100.0, 0.85),
]);
let mut ids: Vec<u64> = tracks.iter().map(|t| t.track_id).collect();
ids.sort();
assert_eq!(ids[0], 1);
}
// `obs_direction` is `None` with a single observation; after a second
// observation shifted by +10 in x it returns roughly `[0, 1]` (order is dy, dx).
#[test]
fn test_ocsort_ocv_velocity_computed() {
let kf = KalmanFilter::default();
let tlwh = [100.0_f32, 100.0, 50.0, 100.0];
let mut track = OcSortTrack::new(tlwh, 0.9, 0, 1, 1, &kf);
assert!(track.obs_direction(3).is_none());
let tlwh2 = [110.0_f32, 100.0, 50.0, 100.0];
let xyah2 = tlwh_to_xyah(&tlwh2);
track.push_observation(xyah2, 2, 4);
let dir = track.obs_direction(3);
assert!(dir.is_some());
let d = dir.unwrap();
assert!(
d[0].abs() < 0.01,
"Expected dy direction ~0.0, got {}",
d[0]
);
assert!(
(d[1] - 1.0).abs() < 0.01,
"Expected dx direction ~1.0, got {}",
d[1]
);
}
// A track created at frame 1 and predicted five times is re-updated at
// frame 7 (a six-frame gap); the resulting box must contain only finite values.
#[test]
fn test_ocsort_our_does_not_panic_on_re_association() {
let kf = KalmanFilter::default();
let tlwh = [100.0_f32, 100.0, 50.0, 100.0];
let mut track = OcSortTrack::new(tlwh, 0.9, 0, 1, 1, &kf);
for _ in 0..5 {
track.predict(&kf);
}
let xyah_new = tlwh_to_xyah(&[130.0_f32, 100.0, 50.0, 100.0]);
track.our_re_update(&xyah_new, 7, &kf);
assert!(track.tlwh.iter().all(|v| v.is_finite()));
}
// `OcSort::default()` uses max_age 30, min_hits 3, iou_threshold 0.3,
// delta_t 3 and inertia 0.2.
#[test]
fn test_ocsort_default_params() {
let tracker = OcSort::default();
assert_eq!(tracker.max_age, 30);
assert_eq!(tracker.min_hits, 3);
assert!((tracker.iou_threshold - 0.3).abs() < 1e-5);
assert_eq!(tracker.delta_t, 3);
assert!((tracker.inertia - 0.2).abs() < 1e-5);
}
// Two trackers keep independent id counters and frame counters.
#[test]
fn test_ocsort_instance_isolation() {
let mut a = OcSort::new(30, 1, 0.3, 3, 0.2);
let mut b = OcSort::new(30, 1, 0.3, 3, 0.2);
a.update(vec![det(100.0, 100.0, 50.0, 100.0, 0.9)]);
let tracks_b = b.update(vec![det(200.0, 200.0, 50.0, 100.0, 0.9)]);
assert_eq!(tracks_b[0].track_id, 1);
assert_eq!(a.frame_count, 1);
assert_eq!(b.frame_count, 1);
}
// After one missed frame the streak restarts: the re-matched track reports
// `hit_streak == 1`.
#[test]
fn test_ocsort_hit_streak_resets_on_miss() {
let mut tracker = OcSort::new(30, 1, 0.3, 3, 0.2);
tracker.update(vec![det(100.0, 100.0, 50.0, 100.0, 0.9)]);
tracker.update(vec![]);
let tracks = tracker.update(vec![det(100.0, 100.0, 50.0, 100.0, 0.9)]);
assert_eq!(tracks.len(), 1);
assert_eq!(tracks[0].hit_streak, 1);
}
// A track that misses one frame is picked up again with the same id.
// NOTE(review): the test name refers to the second association round, but the
// assertions do not establish which round produced the match.
#[test]
fn test_ocsort_second_round_rematches_after_gap() {
let mut tracker = OcSort::new(5, 1, 0.3, 3, 0.2);
for _ in 0..2 {
tracker.update(vec![det(100.0, 100.0, 50.0, 100.0, 0.9)]);
}
tracker.update(vec![]);
let tracks = tracker.update(vec![det(100.0, 100.0, 50.0, 100.0, 0.9)]);
assert_eq!(tracks.len(), 1, "Track should be re-matched after gap");
assert_eq!(tracks[0].track_id, 1, "Should be the same track");
}
// By the third frame the track holds two observations, so association takes
// the direction-bonus branch; the slightly shifted detection still yields one track.
#[test]
fn test_ocsort_ocm_direction_bonus_path() {
let mut tracker = OcSort::new(30, 1, 0.3, 3, 0.2);
tracker.update(vec![det(100.0, 100.0, 50.0, 100.0, 0.9)]); tracker.update(vec![det(100.0, 100.0, 50.0, 100.0, 0.9)]); let tracks = tracker.update(vec![det(105.0, 100.0, 50.0, 100.0, 0.9)]);
assert_eq!(tracks.len(), 1);
}
// Smoke test with no assertions: a detection far from the existing track
// presumably leaves both unmatched after round one, so the second-round code
// that reads each track's last observation runs without panicking.
#[test]
fn test_ocsort_round2_observations_last_reached() {
let mut tracker = OcSort::new(30, 1, 0.3, 3, 0.2);
tracker.update(vec![det(0.0, 0.0, 50.0, 100.0, 0.9)]); tracker.update(vec![det(0.0, 0.0, 50.0, 100.0, 0.9)]); tracker.update(vec![det(10000.0, 0.0, 50.0, 100.0, 0.9)]);
}
}