megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import pickle
import tarfile
from typing import Tuple

import numpy as np

from ....logger import get_logger
from .meta_vision import VisionDataset
from .utils import _default_dataset_root, load_raw_data_from_url

# Module-level logger shared by the dataset classes below; obtained from the
# project's ``get_logger`` helper using this module's ``__name__``.
logger = get_logger(__name__)


class CIFAR10(VisionDataset):
    r""":class:`~.Dataset` for CIFAR10 meta data.

    Args:
        root: directory that holds (or will hold) the extracted
            ``raw_file_dir`` folder. When ``None``, :attr:`_default_root`
            is used and created if it does not exist.
        train: if ``True`` the files listed in ``train_batch`` are loaded,
            otherwise those listed in ``test_batch``.
        download: if ``True``, a missing ``root`` is created and a missing
            ``raw_file_dir`` folder triggers :meth:`download`. If ``False``,
            either situation raises :class:`ValueError`.
        timeout: stored on the instance as ``self.timeout``.
            NOTE(review): nothing in this class forwards the value to the
            download helper -- confirm whether that is intended.

    Raises:
        ValueError: if ``download`` is ``False`` and either ``root`` or the
            extracted data folder does not exist.
    """

    url_path = "http://www.cs.utoronto.ca/~kriz/"
    raw_file_name = "cifar-10-python.tar.gz"
    raw_file_md5 = "c58f30108f718f92721af3b95e74349a"
    raw_file_dir = "cifar-10-batches-py"
    train_batch = [
        "data_batch_1",
        "data_batch_2",
        "data_batch_3",
        "data_batch_4",
        "data_batch_5",
    ]
    test_batch = ["test_batch"]
    meta_info = {"name": "batches.meta"}

    def __init__(
        self,
        root: str = None,
        train: bool = True,
        download: bool = True,
        timeout: int = 500,
    ):
        super().__init__(root, order=("image", "image_category"))

        self.timeout = timeout

        # process the root path
        if root is None:
            self.root = self._default_root
            # exist_ok avoids a check-then-create race when several
            # processes construct the dataset at the same time.
            os.makedirs(self.root, exist_ok=True)
        else:
            self.root = root
            if not os.path.exists(self.root):
                if not download:
                    raise ValueError("dir %s does not exist" % self.root)
                logger.debug(
                    "dir %s does not exist, will be automatically created",
                    self.root,
                )
                os.makedirs(self.root, exist_ok=True)

        self.target_file = os.path.join(self.root, self.raw_file_dir)

        # If the extracted directory already exists it is loaded as-is, no
        # matter what ``download`` is set to; otherwise it must be fetched.
        if not os.path.exists(self.target_file):
            if not download:
                raise ValueError(
                    "dir does not contain target file %s, please set download=True"
                    % (self.target_file)
                )
            self.download()

        batch_files = self.train_batch if train else self.test_batch
        self.arrays = self.bytes2array(batch_files)

    def __getitem__(self, index: int) -> Tuple:
        """Return the ``index``-th entry of every array in ``self.arrays``."""
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        """Return the length of the first array in ``self.arrays``."""
        return len(self.arrays[0])

    @property
    def _default_root(self):
        """Default data directory: the default dataset root joined with the class name."""
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        """Unpickle and return the meta file found in the extracted data folder.

        The file is read with ``encoding="bytes"``.
        """
        meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
        # NOTE(review): pickle can execute arbitrary code while loading; this
        # relies on the archive being the expected one.
        with open(meta_path, "rb") as f:
            meta = pickle.load(f, encoding="bytes")
        return meta

    def download(self):
        """Fetch the raw archive into ``self.root`` and then extract it.

        The URL, file name, MD5 string and target directory are handed to
        ``load_raw_data_from_url``; :meth:`process` is called afterwards.
        """
        url = self.url_path + self.raw_file_name
        load_raw_data_from_url(url, self.raw_file_name, self.raw_file_md5, self.root)
        self.process()

    def untar(self, file_path, dirs):
        """Extract the ``.tar.gz`` archive at ``file_path`` into ``dirs``.

        Args:
            file_path: path of the archive; must end with ``.tar.gz``.
            dirs: directory the archive members are extracted into.
        """
        assert file_path.endswith(".tar.gz")
        logger.debug("untar file %s to %s", file_path, dirs)
        # The context manager closes the archive handle even if extraction
        # fails part-way.
        # NOTE(review): extractall trusts member paths inside the archive;
        # consider the ``filter`` argument on Python versions that support it.
        with tarfile.open(file_path) as t:
            t.extractall(path=dirs)

    def bytes2array(self, filenames):
        """Load pickled batch files into image and label containers.

        Args:
            filenames: names of batch files inside the extracted data folder.

        Returns:
            A tuple ``(data, label)``: ``data`` is a list with one
            ``(32, 32, 3)`` array per image, ``label`` is an ``np.int32``
            array built from every batch's ``b"labels"`` entry.
        """
        data = []
        label = []
        for filename in filenames:
            path = os.path.join(self.root, self.raw_file_dir, filename)
            logger.debug("unpickle file %s", path)
            # NOTE(review): pickle can execute arbitrary code while loading;
            # this relies on the archive being the expected one.
            with open(path, "rb") as fo:
                dic = pickle.load(fo, encoding="bytes")
                # (N, 3, 32, 32) -> (N, 32, 32, 3): move channels to the last axis.
                batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
                # Reverse the channel order (presumably RGB -> BGR; confirm
                # against the dataset description), one list item per image.
                data.extend(batch_data[..., [2, 1, 0]])
                label.extend(dic[b"labels"])
        label = np.array(label, dtype=np.int32)
        return (data, label)

    def process(self):
        """Extract the downloaded archive located in ``self.root`` into ``self.root``."""
        logger.info("process raw data ...")
        self.untar(os.path.join(self.root, self.raw_file_name), self.root)


class CIFAR100(CIFAR10):
    r""":class:`~.Dataset` for CIFAR100 meta data.

    Overrides the archive description of :class:`CIFAR10` and returns two
    label arrays (fine and coarse) per sample instead of one.
    """

    url_path = "http://www.cs.utoronto.ca/~kriz/"
    raw_file_name = "cifar-100-python.tar.gz"
    raw_file_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    raw_file_dir = "cifar-100-python"
    train_batch = ["train"]
    test_batch = ["test"]
    meta_info = {"name": "meta"}

    @property
    def meta(self):
        """Unpickle (``encoding="bytes"``) and return the dataset's meta file."""
        location = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
        with open(location, "rb") as stream:
            return pickle.load(stream, encoding="bytes")

    def bytes2array(self, filenames):
        """Load pickled batch files into images plus fine and coarse labels.

        Args:
            filenames: names of batch files inside the extracted data folder.

        Returns:
            A tuple ``(data, fine_label, coarse_label)``: ``data`` is a list
            with one ``(32, 32, 3)`` array per image, the two label entries
            are ``np.int32`` arrays.
        """
        images, fine, coarse = [], [], []
        for name in filenames:
            batch_path = os.path.join(self.root, self.raw_file_dir, name)
            logger.debug("unpickle file %s", batch_path)
            with open(batch_path, "rb") as stream:
                record = pickle.load(stream, encoding="bytes")
            # (N, 3, 32, 32) -> (N, 32, 32, 3), then reverse the channel axis.
            pixels = record[b"data"].reshape(-1, 3, 32, 32)
            pixels = pixels.transpose((0, 2, 3, 1))
            images += list(pixels[..., [2, 1, 0]])
            fine += record[b"fine_labels"]
            coarse += record[b"coarse_labels"]
        return (
            images,
            np.array(fine, dtype=np.int32),
            np.array(coarse, dtype=np.int32),
        )