megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import hashlib
import os
import tarfile

from ....distributed.group import is_distributed
from ....logger import get_logger
from ....utils.http_download import download_from_url

IMG_EXT = (".jpg", ".png", ".jpeg", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")

logger = get_logger(__name__)


def _default_dataset_root():
    default_dataset_root = os.path.expanduser(
        os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "megengine")
    )

    return default_dataset_root


def load_raw_data_from_url(url: str, filename: str, target_md5: str, raw_data_dir: str):
    cached_file = os.path.join(raw_data_dir, filename)
    logger.debug(
        "load_raw_data_from_url: downloading to or using cached %s ...", cached_file
    )
    if not os.path.exists(cached_file):
        if is_distributed():
            logger.warning(
                "Downloading raw data in DISTRIBUTED mode\n"
                "    File may be downloaded multiple times. We recommend\n"
                "    users to download in single process first."
            )
        md5 = download_from_url(url, cached_file)
    else:
        md5 = calculate_md5(cached_file)
    if target_md5 == md5:
        logger.debug("%s exists with correct md5: %s", filename, target_md5)
    else:
        os.remove(cached_file)
        raise RuntimeError("{} exists but fail to match md5".format(filename))


def calculate_md5(filename):
    m = hashlib.md5()
    with open(filename, "rb") as f:
        while True:
            data = f.read(4096)
            if not data:
                break
            m.update(data)
    return m.hexdigest()


def is_img(filename):
    return filename.lower().endswith(IMG_EXT)


def untar(path, to=None, remove=False):
    if to is None:
        to = os.path.dirname(path)
    with tarfile.open(path, "r") as tar:
        tar.extractall(path=to)

    if remove:
        os.remove(path)


def untargz(path, to=None, remove=False):
    if path.endswith(".tar.gz"):
        if to is None:
            to = os.path.dirname(path)
        with tarfile.open(path, "r:gz") as tar:
            tar.extractall(path=to)
    else:
        raise ValueError("path %s does not end with .tar" % path)

    if remove:
        os.remove(path)