# megenginelite-sys 1.8.2
#
# A safe megenginelite wrapper in Rust
# Documentation
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import argparse
import contextlib
import getpass
import os
import sys
import urllib.parse

import filelock

from ..core._imperative_rt import PersistentCache as _PersistentCache
from ..logger import get_logger
from ..version import __version__, git_version


class PersistentCacheOnServer(_PersistentCache):
    """Fastrun persistent cache configured from environment variables.

    ``__init__`` registers backends with ``add_config`` in this order:

    * ``redis`` -- only when ``MGE_FASTRUN_CACHE_TYPE`` is neither ``"FILE"``
      nor ``"MEMORY"`` and ``MGE_FASTRUN_CACHE_URL`` is set;
    * ``in-file`` -- unless ``MGE_FASTRUN_CACHE_TYPE`` is ``"MEMORY"``;
    * ``in-memory`` -- always, registered last.

    The log messages treat later entries as fallbacks for earlier ones. The
    selection logic itself lives in the ``_PersistentCache`` base class.
    """

    def __init__(self):
        super().__init__()
        cache_type = os.getenv("MGE_FASTRUN_CACHE_TYPE")
        if cache_type not in ("FILE", "MEMORY"):
            try:
                redis_config = self.get_redis_config()
            except Exception as exc:
                # Any config problem (bad URL, unsupported platform, ...) is
                # logged and the redis backend is simply not registered.
                get_logger().error(
                    "failed to connect to cache server {!r}; try fallback to "
                    "in-file cache".format(exc)
                )
            else:
                if redis_config is not None:
                    self.add_config(
                        "redis",
                        redis_config,
                        "fastrun use redis cache",
                        "failed to connect to cache server",
                    )
        if cache_type != "MEMORY":
            try:
                path = self.get_cache_file(self.get_cache_dir())
            except OSError as exc:
                # An unwritable/invalid cache directory must not make the
                # constructor fail: skip the in-file backend and rely on the
                # in-memory one registered below.
                get_logger().error(
                    "failed to create cache file {!r}; try fallback to "
                    "in-memory cache".format(exc)
                )
            else:
                self.add_config(
                    "in-file",
                    {"path": path},
                    "fastrun use in-file cache in {}".format(path),
                    "failed to create cache file in {}".format(path),
                )
        self.add_config(
            "in-memory",
            {},
            "fastrun use in-memory cache",
            "failed to create in-memory cache",
        )

    def get_cache_dir(self):
        """Return the cache directory, creating it if needed.

        Uses ``MGE_FASTRUN_CACHE_DIR`` when set and non-empty; otherwise
        ``<megengine home>/persistent_cache`` (with ``~`` expanded).

        Raises:
            OSError: if the directory cannot be created (propagated from
                ``os.makedirs``).
        """
        cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR")
        if not cache_dir:
            from ..hub.hub import _get_megengine_home

            cache_dir = os.path.expanduser(
                os.path.join(_get_megengine_home(), "persistent_cache")
            )
        os.makedirs(cache_dir, exist_ok=True)
        return cache_dir

    def get_cache_file(self, cache_dir):
        """Return the path of ``cache.bin`` in *cache_dir*, creating it if absent.

        The file is opened in append mode purely to "touch" it, so existing
        contents are never truncated.

        Raises:
            OSError: if the file cannot be created or opened.
        """
        cache_file = os.path.join(cache_dir, "cache.bin")
        with open(cache_file, "a"):
            pass
        return cache_file

    @contextlib.contextmanager
    def lock_cache_file(self, cache_dir):
        """Hold a ``filelock.FileLock`` on ``cache.lock`` in *cache_dir* while active."""
        lock_file = os.path.join(cache_dir, "cache.lock")
        with filelock.FileLock(lock_file):
            yield

    def get_redis_config(self):
        """Build the redis backend config from ``MGE_FASTRUN_CACHE_URL``.

        Accepted URL forms:

        * ``redis://host:port`` -- hostname and port required, no path;
        * ``redis+socket:///path/to/socket`` -- no host/port, path required.

        A password in the URL is forwarded as ``"password"``. A username is
        rejected. Every config also carries a ``"prefix"`` built from the
        current user, the MegEngine version and the git version.

        Returns:
            The config dict, or ``None`` when ``MGE_FASTRUN_CACHE_URL`` is unset.

        Raises:
            RuntimeError: on Windows (redis cache is untested there).
            ValueError: if the URL is malformed or uses an unsupported scheme.
                These are explicit raises rather than ``assert`` so the
                checks still run under ``python -O``.
        """
        url = os.getenv("MGE_FASTRUN_CACHE_URL")
        if url is None:
            return None
        if sys.platform == "win32":
            raise RuntimeError("redis cache on windows not tested")
        prefix = "mgbcache:{}:MGB{}:GIT:{}".format(
            getpass.getuser(), __version__, git_version
        )
        parse_result = urllib.parse.urlparse(url)
        if parse_result.username:
            raise ValueError("redis conn with username unsupported")
        if parse_result.scheme == "redis":
            # Note: accessing .port itself raises ValueError for a bad port.
            if not (parse_result.hostname and parse_result.port):
                raise ValueError("invalid url: hostname and port are required")
            if parse_result.path:
                raise ValueError(
                    "invalid url: unexpected path {!r}".format(parse_result.path)
                )
            config = {
                "hostname": parse_result.hostname,
                "port": str(parse_result.port),
            }
        elif parse_result.scheme == "redis+socket":
            if parse_result.hostname or parse_result.port:
                raise ValueError("invalid url: socket url must not have host/port")
            if not parse_result.path:
                raise ValueError("invalid url: socket path is required")
            config = {
                "unixsocket": parse_result.path,
            }
        else:
            raise ValueError("unsupported scheme {!r}".format(parse_result.scheme))
        if parse_result.password is not None:
            config["password"] = parse_result.password
        config["prefix"] = prefix
        return config

    def flush(self):
        """Flush the cache, holding the cache-dir file lock for the in-file backend.

        When ``self.config`` is ``None`` or its type is not ``"in-file"``,
        this method does nothing. NOTE(review): confirm that the redis and
        in-memory backends really need no explicit flush.
        """
        if self.config is not None and self.config.type == "in-file":
            with self.lock_cache_file(self.get_cache_dir()):
                super().flush()


def _clean():
    """CLI action: delete every cache entry and report how many were removed.

    Nothing is printed when ``clean()`` returns ``None``.
    """
    cache = PersistentCacheOnServer()
    deleted = cache.clean()
    if deleted is None:
        return
    print("{} cache entries deleted".format(deleted))


def main():
    """Parse the command line and run the selected cache-management action.

    Currently the only sub-command is ``clean``. A sub-command is mandatory,
    so running with none makes argparse print usage and exit.
    """
    parser = argparse.ArgumentParser(description="manage persistent cache")
    subparsers = parser.add_subparsers(
        description="action to be performed", dest="cmd"
    )
    subparsers.required = True

    clean_parser = subparsers.add_parser(
        "clean", help="clean all the cache of current user"
    )
    clean_parser.set_defaults(action=_clean)

    parsed = parser.parse_args()
    parsed.action()


if __name__ == "__main__":
    main()