from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from omegaconf import MISSING, OmegaConf
class TrackerType(Enum):
    """Tracker implementations that a configuration file can select.

    ``load_config`` reads this value from ``tracker.type`` to decide which
    parameter schema the ``tracker.params`` mapping is validated against.
    """

    # Validated against ``OriginalSortParams``.
    OriginalSort = 0
    # Validated against ``SortParams``.
    Sort = 1
@dataclass
class Tracker:
    """Tracker selection as written in the configuration file.

    Both fields are required (no defaults).
    """

    # Which tracker implementation to use.
    type: TrackerType
    # Free-form parameter mapping; ``load_config`` merges it onto the
    # parameter schema that matches ``type``.
    params: Dict[str, Any]
@dataclass
class OriginalSortParams:
    """Parameter schema used when ``tracker.type`` is ``OriginalSort``.

    Every field has a default, so an empty ``params`` mapping is valid.
    """

    # NOTE(review): presumably the number of frames a track may go
    # unmatched before being dropped -- confirm against the tracker code.
    max_age: int = 1
    # NOTE(review): presumably the number of matches needed before a track
    # is reported -- confirm against the tracker code.
    min_hits: int = 3
    # NOTE(review): presumably the minimum IoU for a detection/track match
    # -- confirm against the tracker code.
    iou_threshold: float = 0.3
class PositionalMetricType(Enum):
    """Kinds of positional metric selectable through ``PositionalMetric``."""

    IoU = 0
    # NOTE(review): presumably Mahalanobis distance -- confirm with the
    # code that consumes this value.
    Maha = 1
@dataclass
class PositionalMetric:
    """A positional metric choice together with its threshold."""

    # Required: which metric to use.
    type: PositionalMetricType
    # Threshold applied with the metric; how it is interpreted depends on
    # the consumer (not visible in this module).
    threshold: float = 0.3
@dataclass
class SortParams:
    """Parameter schema used when ``tracker.type`` is not ``OriginalSort``.

    Every field has a default, so an empty ``params`` mapping is valid.
    The meaning of the individual knobs is defined by the tracker that
    consumes this object, which is not part of this module.
    """

    shards: int = 4
    bbox_history: int = 10
    max_idle_epochs: int = 10
    # A dataclass instance must not be used directly as a default: Python
    # 3.11+ rejects it as a mutable (unhashable) default, and on older
    # versions all SortParams instances would share one object. A factory
    # builds a fresh IoU metric (default threshold) per instance.
    positional_metric: PositionalMetric = field(
        default_factory=lambda: PositionalMetric(PositionalMetricType.IoU)
    )
    # NOTE(review): the element layout of the inner lists is not visible
    # here -- confirm with the tracker code.
    spatio_temporal_constraints: Optional[List[List]] = None
    use_confidence: bool = False
@dataclass
class VisualSortParams:
    """Placeholder parameter schema with no fields.

    It is accepted by ``Config.tracker``'s type annotation, but
    ``load_config`` never selects it as a schema.
    """
@dataclass
class Evaluator:
    """Settings for the evaluation stage."""

    # NOTE(review): presumably the number of worker processes/cores used
    # during evaluation -- confirm with the evaluator code.
    num_cores: int = 1
@dataclass
class ConfigSchema:
    """Schema that the configuration file is merged onto by ``load_config``.

    ``name``, ``data_path``, ``output_path`` and ``tracker`` have no
    defaults; ``evaluator`` falls back to a default ``Evaluator``.
    """

    name: str
    data_path: str
    output_path: str
    # Tracker type plus its raw, not yet schema-validated parameters.
    tracker: Tracker
    # A dataclass instance must not be used directly as a default: Python
    # 3.11+ rejects it as a mutable (unhashable) default, and on older
    # versions all instances would share one Evaluator. Build one per
    # instance instead.
    evaluator: Evaluator = field(default_factory=Evaluator)
@dataclass
class Config(ConfigSchema):
    """Resolved configuration returned by ``load_config``.

    Same fields as ``ConfigSchema``, except that ``tracker`` holds the
    concrete parameter object instead of the ``Tracker`` type/params pair.
    """

    tracker: Union[
        OriginalSortParams,
        SortParams,
        VisualSortParams,
    ] = MISSING
class ConfigException(Exception):
    """Error type for configuration problems."""
def load_config(config_file_path: str) -> Config:
    """Load a configuration file and resolve it into a ``Config``.

    The file is merged onto ``ConfigSchema``, the tracker parameter schema
    is picked from ``tracker.type``, and ``tracker.params`` is merged onto
    that schema and converted to a plain dataclass instance. The merged
    configuration is printed to stdout as YAML.

    Args:
        config_file_path: Path handed to ``OmegaConf.load``.

    Returns:
        A ``Config`` whose ``tracker`` is an ``OriginalSortParams`` or
        ``SortParams`` instance and whose ``evaluator`` is an ``Evaluator``
        instance.
    """
    config = OmegaConf.unsafe_merge(
        ConfigSchema, OmegaConf.load(config_file_path)
    )

    # Only OriginalSort has its own schema; every other tracker type is
    # validated against SortParams.
    if config.tracker.type == TrackerType.OriginalSort:
        tracker_params_schema = OriginalSortParams
    else:
        tracker_params_schema = SortParams

    merged_params = OmegaConf.unsafe_merge(
        tracker_params_schema, config.tracker.params
    )
    tracker_params = OmegaConf.to_object(merged_params)

    print(f'Configuration:\n{OmegaConf.to_yaml(config)}')

    return Config(
        name=config.name,
        data_path=config.data_path,
        output_path=config.output_path,
        tracker=tracker_params,
        # Convert the OmegaConf node to a real Evaluator so the field
        # matches its declared type, as is already done for the tracker
        # parameters above.
        evaluator=OmegaConf.to_object(config.evaluator),
    )