from abc import abstractmethod
from dataclasses import asdict
from typing import Dict, List, Tuple, Union
import numpy as np
from similari import (
Sort as SortImpl,
VisualSort as VisualSortImpl,
BoundingBox,
SpatioTemporalConstraints,
PositionalMetricType,
)
from .config import (
OriginalSortParams,
SortParams,
VisualSortParams,
PositionalMetricType as PositionalMetricConfigType,
)
from .original_sort import Sort as OriginalSortImpl
class Tracker:
@abstractmethod
def process_frame(
self, frame_num: int, detections: List[Tuple[float, float, float, float, float]]
) -> List[Tuple[int, float, float, float, float, float]]:
pass
class OriginalSort(Tracker):
def __init__(self, params: OriginalSortParams):
self._tracker = OriginalSortImpl(**asdict(params))
def process_frame(
self, frame_num: int, detections: List[Tuple[float, float, float, float, float]]
) -> List[Tuple[int, float, float, float, float, float]]:
np_detections = np.array(detections)
np_detections[:, 2:4] += np_detections[:, 0:2]
tracks = self._tracker.update(np_detections)
return [
(
int(track[4]),
track[0],
track[1],
track[2] - track[0],
track[3] - track[1],
1.0,
)
for track in tracks
]
class SimilariTracker(Tracker):
def __init__(self, params: Union[SortParams, VisualSortParams]):
constraints = None
if params.spatio_temporal_constraints:
constraints = SpatioTemporalConstraints()
constraints.add_constraints(
list(map(tuple, params.spatio_temporal_constraints))
)
positional_metric = None
if params.positional_metric:
if params.positional_metric.type == PositionalMetricConfigType.IoU:
positional_metric = PositionalMetricType.iou(
threshold=params.positional_metric.threshold
)
else:
positional_metric = PositionalMetricType.maha()
if isinstance(params, SortParams):
self._tracker = SortImpl(
shards=params.shards,
bbox_history=params.bbox_history,
max_idle_epochs=params.max_idle_epochs,
method=positional_metric,
spatio_temporal_constraints=constraints,
)
else:
raise NotImplementedError
self._use_confidence = params.use_confidence
self._track_id_map: Dict[int, int] = {}
def process_frame(
self, frame_num: int, detections: List[Tuple[float, float, float, float, float]]
) -> List[Tuple[int, float, float, float, float, float]]:
if self._use_confidence:
dets = [
(BoundingBox.new_with_confidence(*detection).as_xyaah(), 0)
for detection in detections
]
else:
dets = [
(BoundingBox(*detection[:-1]).as_xyaah(), 0) for detection in detections
]
tracks = self._tracker.predict(dets)
rows = []
for track in tracks:
track_id = track.id
_bbox = track.predicted_bbox.as_ltwh()
if track_id not in self._track_id_map:
self._track_id_map[track_id] = len(self._track_id_map) + 1
rows.append(
(
self._track_id_map[track_id],
_bbox.left,
_bbox.top,
_bbox.width,
_bbox.height,
1.0,
)
)
return rows