import json
import os
from collections import defaultdict
import cv2
import numpy as np
from .meta_vision import VisionDataset
# Minimum number of visible keypoints, summed over all annotations of one
# image, that has_valid_annotation requires when "keypoints" is requested.
min_keypoints_per_image = 10
def _count_visible_keypoints(anno):
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
def has_valid_annotation(anno, order):
    """Decide whether an image's annotations satisfy the requested targets.

    Args:
        anno: list of annotation dicts belonging to a single image.
        order: iterable of requested target names (e.g. ``"boxes"``,
            ``"boxes_category"``, ``"keypoints"``). ``None`` is treated as
            "no specific targets requested".

    Returns:
        ``False`` when ``anno`` is empty, when a requested target's key
        (``"bbox"`` / ``"keypoints"``) is missing from the first annotation,
        or when keypoints are requested and fewer than
        ``min_keypoints_per_image`` of them are visible; ``True`` otherwise.
    """
    if len(anno) == 0:
        return False

    # Callers may forward an unset ``order``; a membership test on None would
    # raise TypeError, so fall back to an empty selection instead.
    if order is None:
        order = ()

    # Only the first annotation is inspected for key presence.
    first = anno[0]
    if ("boxes" in order or "boxes_category" in order) and "bbox" not in first:
        return False

    if "keypoints" in order:
        if "keypoints" not in first:
            return False
        if _count_visible_keypoints(anno) < min_keypoints_per_image:
            return False

    return True
class COCO(VisionDataset):
    """Dataset backed by a COCO-style JSON annotation file.

    Indexing returns a tuple with one entry per name in ``self.order``
    (see ``__getitem__`` for the layout of each entry).
    """

    # Target names forwarded to the base class as the supported selection.
    # NOTE(review): ``__getitem__`` also has a "polygons" branch that is not
    # listed here - confirm whether that omission is intentional.
    supported_order = (
        "image",
        "boxes",
        "boxes_category",
        "keypoints",
        "info",
    )

    def __init__(
        self, root, ann_file, remove_images_without_annotations=False, *, order=None
    ):
        """Load the annotation file and index it by image id.

        Args:
            root: directory that image ``file_name`` entries are joined onto.
            ann_file: path of the JSON annotation file; it must contain the
                ``"images"``, ``"annotations"`` and ``"categories"`` lists.
            remove_images_without_annotations: when true, crowd annotations
                and annotations with a non-positive box width/height are
                dropped, and images left without a valid annotation (see
                ``has_valid_annotation``) are removed from the dataset.
            order: requested target names, forwarded to the base class.
        """
        super().__init__(root, order=order, supported_order=self.supported_order)

        with open(ann_file, "r") as f:
            dataset = json.load(f)

        # Image records keyed by image id. Metadata fields that this class
        # never reads are dropped (presumably to reduce memory usage).
        self.imgs = dict()
        for img in dataset["images"]:
            for unused_key in ("license", "coco_url", "date_captured", "flickr_url"):
                img.pop(unused_key, None)
            self.imgs[img["id"]] = img

        # Annotations grouped by image id; fields for targets that were not
        # requested are discarded.
        keep_bbox = "boxes" in self.order or "boxes_category" in self.order
        keep_segmentation = "polygons" in self.order
        self.img_to_anns = defaultdict(list)
        for ann in dataset["annotations"]:
            if not keep_bbox:
                ann.pop("bbox", None)
            if not keep_segmentation:
                ann.pop("segmentation", None)
            self.img_to_anns[ann["image_id"]].append(ann)

        self.cats = {cat["id"]: cat for cat in dataset["categories"]}

        self.ids = sorted(self.imgs.keys())

        if remove_images_without_annotations:
            kept_ids = []
            for img_id in self.ids:
                anno = self.img_to_anns[img_id]
                anno = [obj for obj in anno if obj["iscrowd"] == 0]
                # "bbox" has been removed above when boxes were not requested,
                # so the degenerate-box filter only applies where it exists.
                anno = [
                    obj
                    for obj in anno
                    if "bbox" not in obj
                    or (obj["bbox"][2] > 0 and obj["bbox"][3] > 0)
                ]
                # Validate against the resolved ``self.order``; the raw
                # ``order`` argument may be None.
                if has_valid_annotation(anno, self.order):
                    kept_ids.append(img_id)
                    self.img_to_anns[img_id] = anno
                else:
                    del self.imgs[img_id]
                    del self.img_to_anns[img_id]
            self.ids = kept_ids

        # Category ids from the file, in ascending order, mapped to 1..N.
        self.json_category_id_to_contiguous_id = {
            v: i + 1 for i, v in enumerate(sorted(self.cats.keys()))
        }

        # Inverse of the mapping above.
        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }

    def __getitem__(self, index):
        """Return the targets of the ``index``-th image as a tuple.

        The tuple has one entry per name in ``self.order``:

        * ``"image"``: result of ``cv2.imread`` on ``root/file_name``
          (``cv2.imread`` returns ``None`` for an unreadable file; that is
          not checked here).
        * ``"boxes"``: float32 array of shape ``(N, 4)``; the last two
          columns of each ``"bbox"`` are offset by the first two.
        * ``"boxes_category"``: int32 array of contiguous category ids.
        * ``"keypoints"``: float32 array of shape
          ``(N, len(keypoint_names), 3)``.
        * ``"polygons"``: per annotation, a list of float32 ``(-1, 2)`` arrays.
        * ``"info"``: ``[height, width, file_name]``.

        Raises:
            NotImplementedError: for any other name in ``self.order``.
        """
        img_id = self.ids[index]
        anno = self.img_to_anns[img_id]

        target = []
        for k in self.order:
            if k == "image":
                file_name = self.imgs[img_id]["file_name"]
                path = os.path.join(self.root, file_name)
                image = cv2.imread(path, cv2.IMREAD_COLOR)
                target.append(image)
            elif k == "boxes":
                boxes = [obj["bbox"] for obj in anno]
                # reshape keeps a (0, 4) shape when there are no annotations.
                boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
                # Turn the (width, height) columns into bottom-right corners.
                boxes[:, 2:] += boxes[:, :2]
                target.append(boxes)
            elif k == "boxes_category":
                boxes_category = [
                    self.json_category_id_to_contiguous_id[obj["category_id"]]
                    for obj in anno
                ]
                boxes_category = np.array(boxes_category, dtype=np.int32)
                target.append(boxes_category)
            elif k == "keypoints":
                keypoints = [obj["keypoints"] for obj in anno]
                keypoints = np.array(keypoints, dtype=np.float32).reshape(
                    -1, len(self.keypoint_names), 3
                )
                target.append(keypoints)
            elif k == "polygons":
                polygons = [obj["segmentation"] for obj in anno]
                polygons = [
                    [np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps]
                    for ps in polygons
                ]
                target.append(polygons)
            elif k == "info":
                info = self.imgs[img_id]
                info = [info["height"], info["width"], info["file_name"]]
                target.append(info)
            else:
                raise NotImplementedError

        return tuple(target)

    def __len__(self):
        """Return the number of images kept in the dataset."""
        return len(self.ids)

    def get_img_info(self, index):
        """Return the stored image record (a dict) of the ``index``-th image."""
        img_id = self.ids[index]
        img_info = self.imgs[img_id]
        return img_info

    # Category names.
    class_names = (
        "person",
        "bicycle",
        "car",
        "motorcycle",
        "airplane",
        "bus",
        "train",
        "truck",
        "boat",
        "traffic light",
        "fire hydrant",
        "stop sign",
        "parking meter",
        "bench",
        "bird",
        "cat",
        "dog",
        "horse",
        "sheep",
        "cow",
        "elephant",
        "bear",
        "zebra",
        "giraffe",
        "backpack",
        "umbrella",
        "handbag",
        "tie",
        "suitcase",
        "frisbee",
        "skis",
        "snowboard",
        "sports ball",
        "kite",
        "baseball bat",
        "baseball glove",
        "skateboard",
        "surfboard",
        "tennis racket",
        "bottle",
        "wine glass",
        "cup",
        "fork",
        "knife",
        "spoon",
        "bowl",
        "banana",
        "apple",
        "sandwich",
        "orange",
        "broccoli",
        "carrot",
        "hot dog",
        "pizza",
        "donut",
        "cake",
        "chair",
        "couch",
        "potted plant",
        "bed",
        "dining table",
        "toilet",
        "tv",
        "laptop",
        "mouse",
        "remote",
        "keyboard",
        "cell phone",
        "microwave",
        "oven",
        "toaster",
        "sink",
        "refrigerator",
        "book",
        "clock",
        "vase",
        "scissors",
        "teddy bear",
        "hair drier",
        "toothbrush",
    )

    # Category name -> integer id (presumably the category ids used in the
    # official COCO annotation files - confirm; note the ids are not contiguous).
    classes_originID = {
        "person": 1,
        "bicycle": 2,
        "car": 3,
        "motorcycle": 4,
        "airplane": 5,
        "bus": 6,
        "train": 7,
        "truck": 8,
        "boat": 9,
        "traffic light": 10,
        "fire hydrant": 11,
        "stop sign": 13,
        "parking meter": 14,
        "bench": 15,
        "bird": 16,
        "cat": 17,
        "dog": 18,
        "horse": 19,
        "sheep": 20,
        "cow": 21,
        "elephant": 22,
        "bear": 23,
        "zebra": 24,
        "giraffe": 25,
        "backpack": 27,
        "umbrella": 28,
        "handbag": 31,
        "tie": 32,
        "suitcase": 33,
        "frisbee": 34,
        "skis": 35,
        "snowboard": 36,
        "sports ball": 37,
        "kite": 38,
        "baseball bat": 39,
        "baseball glove": 40,
        "skateboard": 41,
        "surfboard": 42,
        "tennis racket": 43,
        "bottle": 44,
        "wine glass": 46,
        "cup": 47,
        "fork": 48,
        "knife": 49,
        "spoon": 50,
        "bowl": 51,
        "banana": 52,
        "apple": 53,
        "sandwich": 54,
        "orange": 55,
        "broccoli": 56,
        "carrot": 57,
        "hot dog": 58,
        "pizza": 59,
        "donut": 60,
        "cake": 61,
        "chair": 62,
        "couch": 63,
        "potted plant": 64,
        "bed": 65,
        "dining table": 67,
        "toilet": 70,
        "tv": 72,
        "laptop": 73,
        "mouse": 74,
        "remote": 75,
        "keyboard": 76,
        "cell phone": 77,
        "microwave": 78,
        "oven": 79,
        "toaster": 80,
        "sink": 81,
        "refrigerator": 82,
        "book": 84,
        "clock": 85,
        "vase": 86,
        "scissors": 87,
        "teddy bear": 88,
        "hair drier": 89,
        "toothbrush": 90,
    }

    # Keypoint names; their count fixes the middle dimension of the
    # "keypoints" target built in __getitem__.
    keypoint_names = (
        "nose",
        "left_eye",
        "right_eye",
        "left_ear",
        "right_ear",
        "left_shoulder",
        "right_shoulder",
        "left_elbow",
        "right_elbow",
        "left_wrist",
        "right_wrist",
        "left_hip",
        "right_hip",
        "left_knee",
        "right_knee",
        "left_ankle",
        "right_ankle",
    )