import numpy as np
from similari import (
BatchVisualSort,
SpatioTemporalConstraints,
PositionalMetricType,
VisualSortOptions,
VisualSortMetricType,
BoundingBox, VisualSortObservation, VisualSortObservationSet,
VisualSortPredictionBatchRequest
)
def get_opts():
    """Build the VisualSortOptions used by the demo tracker.

    Returns:
        VisualSortOptions: a freshly constructed options object with the
        spatio-temporal constraints, idle/history limits, visual and
        positional metrics, and visual-feature thresholds configured below.
    """
    constraints = SpatioTemporalConstraints()
    # NOTE(review): the tuple is presumably (epoch gap, max allowed distance);
    # confirm against the similari SpatioTemporalConstraints documentation.
    constraints.add_constraints([(1, 1.0)])

    opts = VisualSortOptions()
    opts.spatio_temporal_constraints(constraints)

    # Track lifetime / history limits. These two calls were previously
    # duplicated verbatim with identical values; one call each is enough.
    opts.max_idle_epochs(15)
    opts.kept_history_length(25)

    # Metrics: Euclidean visual metric with a 0.7 parameter, and the `maha()`
    # positional metric (presumably Mahalanobis — verify in similari docs).
    opts.visual_metric(VisualSortMetricType.euclidean(0.7))
    opts.positional_metric(PositionalMetricType.maha())

    # Thresholds governing when visual features are collected and used.
    opts.visual_minimal_track_length(7)
    opts.visual_minimal_area(5.0)
    opts.visual_minimal_quality_use(0.45)
    opts.visual_minimal_quality_collect(0.5)
    opts.visual_max_observations(25)
    opts.visual_min_votes(5)
    return opts
def build_observation(obj):
    """Turn one synthetic detection row into a VisualSortObservation.

    The row layout matches what ``generate_objs`` produces:
    indices 0-3 are passed to ``BoundingBox``, indices 4-131 form the
    feature vector, and index 132 is the feature quality.

    Args:
        obj: a single row (sliceable, e.g. a 1-D NumPy array).

    Returns:
        VisualSortObservation with the box converted via ``as_xyaah()`` and
        no custom object id.
    """
    # Build the box first so construction order matches the original.
    xyaah_box = BoundingBox(*obj[:4]).as_xyaah()
    observation = VisualSortObservation(
        feature=obj[4:132],
        feature_quality=obj[132],
        bounding_box=xyaah_box,
        custom_object_id=None,
    )
    return observation
def generate_objs(n_objs):
    """Return ``n_objs`` random synthetic detection rows.

    Each row has 4 + 128 + 1 = 133 columns drawn from ``np.random.rand``
    (uniform on [0, 1)): four bounding-box values, a 128-element feature
    vector, and one feature-quality value — the layout that
    ``build_observation`` slices apart.
    """
    bbox_cols, feature_cols, quality_cols = 4, 128, 1
    total_cols = bbox_cols + feature_cols + quality_cols
    return np.random.rand(n_objs, total_cols)
def build_prediction_request(objs, n_batches):
    """Pack ``objs`` into a VisualSortPredictionBatchRequest.

    ``objs`` is split into ``n_batches`` consecutive chunks with
    ``np.split``; the chunk index is used as the scene id for every row in
    that chunk, and each row is converted with ``build_observation``.
    """
    request = VisualSortPredictionBatchRequest()
    # Lazily flatten (scene_id, row) pairs; np.split runs on first iteration,
    # i.e. after the request has been constructed, as before.
    scene_rows = (
        (scene_id, row)
        for scene_id, chunk in enumerate(np.split(objs, n_batches))
        for row in chunk
    )
    for scene_id, row in scene_rows:
        request.add(scene_id, build_observation(row))
    return request
def main(n_frames=10, n_objs=6, n_batches=2):
    """Run the batched visual SORT demo on synthetic detections.

    One set of ``n_objs`` random detections is generated and replayed for
    ``n_frames`` frames. Each frame it is split into ``n_batches`` scenes and
    sent to the tracker. The tracks of every returned scene are printed.

    Args:
        n_frames: number of prediction rounds to run.
        n_objs: number of synthetic detections per frame.
        n_batches: number of scenes the detections are split into.

    Raises:
        ValueError: if ``n_batches`` is not positive or does not evenly
            divide ``n_objs``. Explicit exceptions replace ``assert``, which
            is stripped under ``python -O``.
    """
    if n_batches <= 0:
        raise ValueError('n_batches must be a positive integer.')
    if n_objs % n_batches != 0:
        raise ValueError('For simplicity, batches of equal size are expected.')

    tracker = BatchVisualSort(distance_shards=1, voting_shards=1, opts=get_opts())
    objs = generate_objs(n_objs)

    for frame_i in range(n_frames):
        print(f'======== {frame_i} ========')
        batch_request = build_prediction_request(objs, n_batches)
        result = tracker.predict(batch_request)
        for _ in range(result.batch_size()):
            scene_id, tracks = result.get()
            print("Scene", scene_id)
            for track in tracks:
                print(track)
            # A scene may legitimately come back with no tracks; indexing
            # tracks[0] unguarded would raise IndexError in that case.
            if not tracks:
                continue
            first = tracks[0]
            bbox = first.predicted_bbox.as_ltwh()
            print((
                first.id,
                bbox.left,
                bbox.top,
                bbox.width,
                bbox.height,
                bbox.confidence,
            ))

    print('++++ Wasted ++++')
    for w in tracker.wasted():
        print(w)
if __name__ == '__main__':
    # Run the demo with its default sizes only when executed as a script,
    # not when this module is imported.
    main()