megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# ---------------------------------------------------------------------
# Part of the following code in this file refs to torchvision
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# ---------------------------------------------------------------------
import collections.abc
import os
import xml.etree.ElementTree as ET

import cv2
import numpy as np

from .meta_vision import VisionDataset


class PascalVOC(VisionDataset):
    r"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.

    Each sample is a tuple whose entries follow ``order``. Detection targets
    (``"boxes"`` / ``"boxes_category"``) and the segmentation target
    (``"mask"``) cannot be requested together.

    Args:
        root: dataset root containing ``JPEGImages`` plus either
            ``Annotations`` and ``ImageSets/Main`` (detection) or
            ``SegmentationClass`` / ``SegmentationClass_aug`` and
            ``ImageSets/Segmentation`` (segmentation).
        image_set: name of the split file (without ``.txt``) under the
            relevant ``ImageSets`` subdirectory. A name containing ``"aug"``
            selects the ``SegmentationClass_aug`` masks.
        order: sequence of keys from :attr:`supported_order` selecting which
            entries each sample contains.

    Raises:
        ValueError: if box targets and ``"mask"`` are requested together.
        RuntimeError: if ``self.root`` is not a directory.
        NotImplementedError: if ``order`` requests neither box targets nor
            ``"mask"``.
    """

    supported_order = (
        "image",
        "boxes",
        "boxes_category",
        "mask",
        "info",
    )

    def __init__(self, root, image_set, *, order=None):
        # ``order=None`` used to crash with a TypeError on the membership tests
        # below. Treat it as "nothing requested" here and forward the original
        # value unchanged to the base class.
        requested = () if order is None else order
        wants_boxes = "boxes" in requested or "boxes_category" in requested
        wants_mask = "mask" in requested
        if wants_boxes and wants_mask:
            raise ValueError(
                "PascalVOC only supports boxes & boxes_category or mask, not both."
            )

        super().__init__(root, order=order, supported_order=self.supported_order)

        if not os.path.isdir(self.root):
            raise RuntimeError("Dataset not found or corrupted.")

        self.image_set = image_set
        image_dir = os.path.join(self.root, "JPEGImages")

        if wants_boxes:
            annotation_dir = os.path.join(self.root, "Annotations")
            self.file_names = self._read_split_file(
                os.path.join(self.root, "ImageSets/Main"), image_set
            )
            self.annotations = [
                os.path.join(annotation_dir, x + ".xml") for x in self.file_names
            ]
        elif wants_mask:
            mask_subdir = (
                "SegmentationClass_aug" if "aug" in image_set else "SegmentationClass"
            )
            mask_dir = os.path.join(self.root, mask_subdir)
            self.file_names = self._read_split_file(
                os.path.join(self.root, "ImageSets/Segmentation"), image_set
            )
            self.masks = [os.path.join(mask_dir, x + ".png") for x in self.file_names]
        else:
            raise NotImplementedError(
                "PascalVOC needs 'boxes', 'boxes_category' or 'mask' in order."
            )

        # Image paths are derived from the same id list as the targets, so the
        # two lists always have equal length.
        self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]

        # Per-index cache of image metadata filled lazily by get_img_info().
        self.img_infos = dict()

    @staticmethod
    def _read_split_file(split_dir, image_set):
        """Return the sample ids listed in ``<split_dir>/<image_set>.txt``.

        Each line is stripped of surrounding whitespace. Blank lines (for
        example, extra newlines at the end of the file) are skipped so they
        do not turn into bogus empty sample ids.
        """
        split_f = os.path.join(split_dir, image_set.rstrip("\n") + ".txt")
        with open(split_f, "r") as f:
            names = (line.strip() for line in f)
            return [name for name in names if name]

    def _load_objects(self, index):
        """Parse the XML annotation of sample ``index`` and return its objects.

        Because of the wrapping done in :meth:`parse_voc_xml` for the
        ``annotation`` tag, the result is always a list of per-object dicts,
        even when the image has zero or one objects.
        """
        anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
        return anno["annotation"]["object"]

    def __getitem__(self, index):
        """Return sample ``index`` as a tuple with one entry per key in ``self.order``.

        Entries:
            * ``"image"``: result of ``cv2.imread(path, cv2.IMREAD_COLOR)``.
            * ``"boxes"``: float32 array of shape ``(N, 4)`` in xyxy order.
            * ``"boxes_category"``: int32 array of shape ``(N,)`` holding the
              index into :attr:`class_names` plus one. An object name missing
              from ``class_names`` raises ``ValueError`` from ``tuple.index``.
            * ``"mask"``: ``(H, W, 1)`` array. For ``*aug*`` splits this is the
              grayscale PNG; otherwise it is the colour PNG converted by
              :meth:`_trans_mask`.
            * ``"info"``: ``[height, width, file_name]``.
        """
        target = []
        # "info" reuses the decoded image when "image" appeared earlier in
        # ``self.order``. Otherwise get_img_info() reads the file itself.
        # Previously "info" before (or without) "image" raised UnboundLocalError.
        image = None
        # The XML annotation is parsed at most once per sample, even when both
        # "boxes" and "boxes_category" are requested.
        objects = None
        for k in self.order:
            if k == "image":
                image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
                target.append(image)
            elif k == "boxes":
                if objects is None:
                    objects = self._load_objects(index)
                # boxes type xyxy; the XML values are strings, which numpy
                # converts to float32 below.
                boxes = [
                    (bb["xmin"], bb["ymin"], bb["xmax"], bb["ymax"])
                    for bb in (obj["bndbox"] for obj in objects)
                ]
                target.append(np.array(boxes, dtype=np.float32).reshape(-1, 4))
            elif k == "boxes_category":
                if objects is None:
                    objects = self._load_objects(index)
                boxes_category = [
                    self.class_names.index(obj["name"]) + 1 for obj in objects
                ]
                target.append(np.array(boxes_category, dtype=np.int32))
            elif k == "mask":
                if "aug" in self.image_set:
                    mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                else:
                    mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR)
                    mask = self._trans_mask(mask)
                target.append(mask[:, :, np.newaxis])
            elif k == "info":
                info = self.get_img_info(index, image)
                target.append([info["height"], info["width"], info["file_name"]])
            else:
                raise NotImplementedError

        return tuple(target)

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self.images)

    def get_img_info(self, index, image=None):
        """Return and cache ``dict(height=..., width=..., file_name=...)`` for ``index``.

        ``image`` may be passed to avoid decoding the file again. It is only
        consulted on a cache miss; when it is ``None`` the image is read from
        disk.
        """
        if index not in self.img_infos:
            if image is None:
                image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
            self.img_infos[index] = dict(
                height=image.shape[0],
                width=image.shape[1],
                file_name=self.file_names[index],
            )
        return self.img_infos[index]

    def _trans_mask(self, mask):
        """Convert a colour-coded mask into a per-pixel class-index map.

        ``mask`` is compared channel-wise against each ``[b, g, r]`` entry of
        :attr:`class_colors` (channel 0 against ``b``, matching the channel
        order of ``cv2.imread``). A pixel matching ``class_colors[i]`` gets
        label ``i``; pixels matching no entry keep the value 255. Returns a
        uint8 array of shape ``mask.shape[:2]``.
        """
        label = np.full(mask.shape[:2], 255, dtype=np.uint8)
        for idx, (b, g, r) in enumerate(self.class_colors):
            hit = (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)
            label[hit] = idx
        return label

    def parse_voc_xml(self, node):
        """Recursively convert a VOC annotation XML element into nested dicts.

        A leaf element becomes ``{tag: stripped_text}``, or ``{}`` if it has no
        text. A non-leaf element becomes ``{tag: {child_tag: value}}``. When a
        child tag occurs more than once, its value is a list of the children's
        values.

        For the ``annotation`` element, the collected ``object`` entries are
        wrapped once more. After the single-element unwrap this leaves
        ``["annotation"]["object"]`` as a list, even for zero or one object.
        """
        children = list(node)
        if not children:
            return {node.tag: node.text.strip()} if node.text else {}

        grouped = collections.defaultdict(list)
        for child_dict in map(self.parse_voc_xml, children):
            for tag, value in child_dict.items():
                grouped[tag].append(value)
        if node.tag == "annotation":
            grouped["object"] = [grouped["object"]]
        return {
            node.tag: {tag: v[0] if len(v) == 1 else v for tag, v in grouped.items()}
        }

    # Category names. boxes_category reports index-in-this-tuple + 1.
    class_names = (
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    )

    # Mask colours stored as [B, G, R] (see _trans_mask). Index 0 is
    # background, so index i presumably corresponds to class_names[i - 1],
    # matching the +1 offset used for boxes_category — TODO confirm.
    class_colors = [
        [0, 0, 0],  # background
        [0, 0, 128],
        [0, 128, 0],
        [0, 128, 128],
        [128, 0, 0],
        [128, 0, 128],
        [128, 128, 0],
        [128, 128, 128],
        [0, 0, 64],
        [0, 0, 192],
        [0, 128, 64],
        [0, 128, 192],
        [128, 0, 64],
        [128, 0, 192],
        [128, 128, 64],
        [128, 128, 192],
        [0, 64, 0],
        [0, 64, 128],
        [0, 192, 0],
        [0, 192, 128],
        [128, 64, 0],
    ]