similari 0.24.2

Machine learning framework for building object trackers and similarity search engines
Documentation
"""Unified tracker interface for supported trackers."""
from abc import ABC, abstractmethod
from dataclasses import asdict
from typing import Dict, List, Tuple, Union

import numpy as np
from similari import (
    Sort as SortImpl,
    VisualSort as VisualSortImpl,
    BoundingBox,
    SpatioTemporalConstraints,
    PositionalMetricType,
)

from .config import (
    OriginalSortParams,
    SortParams,
    VisualSortParams,
    PositionalMetricType as PositionalMetricConfigType,
)
from .original_sort import Sort as OriginalSortImpl


class Tracker(ABC):
    """Common interface implemented by every supported tracker.

    Deriving from ``ABC`` makes ``@abstractmethod`` effective: the base class
    and any subclass that does not implement ``process_frame`` cannot be
    instantiated.
    """

    @abstractmethod
    def process_frame(
        self, frame_num: int, detections: List[Tuple[float, float, float, float, float]]
    ) -> List[Tuple[int, float, float, float, float, float]]:
        """Process one frame of detections and return the resulting tracks.

        Args:
            frame_num: Index of the frame being processed.
            detections: ``(left, top, width, height, confidence)`` tuples.

        Returns:
            ``(track_id, left, top, width, height, confidence)`` tuples.
        """
        raise NotImplementedError


class OriginalSort(Tracker):
    """``Tracker`` adapter around the bundled reference SORT implementation."""

    def __init__(self, params: OriginalSortParams):
        # Forward every field of the params dataclass as a keyword argument.
        self._tracker = OriginalSortImpl(**asdict(params))

    def process_frame(
        self, frame_num: int, detections: List[Tuple[float, float, float, float, float]]
    ) -> List[Tuple[int, float, float, float, float, float]]:
        """Feed one frame of detections to SORT and return its current tracks.

        Args:
            frame_num: Frame index; not used by this adapter.
            detections: ``(left, top, width, height, confidence)`` tuples.
                May be empty for frames without detections.

        Returns:
            ``(track_id, left, top, width, height, confidence)`` tuples.
            The confidence is always reported as ``1.0``.
        """
        # (left, top, width, height, confidence) -> [x1, y1, x2, y2, confidence].
        # np.array copies its input, so the caller's data is never modified.
        np_detections = np.array(detections, dtype=float)
        if np_detections.size == 0:
            # np.array([]) is 1-D and the 2-D column slice below would raise
            # IndexError; pass SORT an explicit empty (0, 5) array instead.
            np_detections = np.empty((0, 5))
        np_detections[:, 2:4] += np_detections[:, 0:2]
        tracks = self._tracker.update(np_detections)
        # Each returned row is read as [x1, y1, x2, y2, track_id].
        return [
            (
                int(track[4]),
                float(track[0]),
                float(track[1]),
                float(track[2] - track[0]),
                float(track[3] - track[1]),
                1.0,
            )
            for track in tracks
        ]


class SimilariTracker(Tracker):
    """``Tracker`` adapter around a similari tracker.

    Only ``SortParams`` is currently supported; any other params type raises
    ``NotImplementedError`` from the constructor.
    """

    def __init__(self, params: Union[SortParams, VisualSortParams]):
        # Reject unsupported configurations before reading any params fields,
        # so callers get NotImplementedError rather than an unrelated error
        # about a field the other params type may not define.
        if not isinstance(params, SortParams):
            raise NotImplementedError(
                f"{type(self).__name__} does not support "
                f"{type(params).__name__}; only SortParams is implemented"
            )

        constraints = self._build_constraints(params)
        positional_metric = self._build_positional_metric(params)

        self._tracker = SortImpl(
            shards=params.shards,
            bbox_history=params.bbox_history,
            max_idle_epochs=params.max_idle_epochs,
            method=positional_metric,
            spatio_temporal_constraints=constraints,
        )

        self._use_confidence = params.use_confidence
        # Maps similari's internal track ids to sequential 1-based ids,
        # assigned in order of first appearance.
        self._track_id_map: Dict[int, int] = {}

    @staticmethod
    def _build_constraints(params: SortParams):
        """Return configured SpatioTemporalConstraints, or None if none are set."""
        if not params.spatio_temporal_constraints:
            return None
        constraints = SpatioTemporalConstraints()
        # Each configured entry is converted to a tuple before being handed to
        # similari (presumably (epoch_delta, max_distance) pairs — TODO confirm).
        constraints.add_constraints(
            [tuple(entry) for entry in params.spatio_temporal_constraints]
        )
        return constraints

    @staticmethod
    def _build_positional_metric(params: SortParams):
        """Translate the configured positional metric to similari's type, or None."""
        metric_config = params.positional_metric
        if not metric_config:
            return None
        if metric_config.type == PositionalMetricConfigType.IoU:
            return PositionalMetricType.iou(threshold=metric_config.threshold)
        # Any non-IoU configured type selects the Mahalanobis metric.
        return PositionalMetricType.maha()

    def process_frame(
        self, frame_num: int, detections: List[Tuple[float, float, float, float, float]]
    ) -> List[Tuple[int, float, float, float, float, float]]:
        """Feed one frame of detections to similari and return predicted tracks.

        Args:
            frame_num: Frame index; currently unused.
            detections: ``(left, top, width, height, confidence)`` tuples.
                The confidence is passed to similari only when
                ``use_confidence`` was enabled; otherwise it is dropped.

        Returns:
            ``(track_id, left, top, width, height, confidence)`` tuples built
            from each track's predicted box. Track ids are remapped to
            sequential 1-based ids; confidence is always ``1.0``.
        """
        # The second tuple element is passed as 0 for every detection
        # (NOTE(review): presumably a class / custom-object id — confirm).
        if self._use_confidence:
            dets = [
                (BoundingBox.new_with_confidence(*detection).as_xyaah(), 0)
                for detection in detections
            ]
        else:
            dets = [
                (BoundingBox(*detection[:-1]).as_xyaah(), 0) for detection in detections
            ]

        tracks = self._tracker.predict(dets)

        rows = []
        for track in tracks:
            bbox = track.predicted_bbox.as_ltwh()
            # setdefault evaluates len() before inserting, so a new id gets
            # the next 1-based number.
            track_id = self._track_id_map.setdefault(
                track.id, len(self._track_id_map) + 1
            )
            rows.append((track_id, bbox.left, bbox.top, bbox.width, bbox.height, 1.0))

        # TODO: periodically call self._tracker.wasted() (e.g. every N frames).
        # NOTE(review): nothing here ever drains expired tracks or prunes
        # _track_id_map, so both may grow over long streams — confirm against
        # similari's Sort.wasted() semantics before enabling.
        return rows