import functools
import hashlib
import os
import sys
import types
from typing import Any, List
from urllib.parse import urlparse
from megengine.utils.http_download import download_from_url
from ..distributed import is_distributed
from ..logger import get_logger
from ..serialization import load as _mge_load_serialized
from .const import (
DEFAULT_CACHE_DIR,
DEFAULT_GIT_HOST,
DEFAULT_PROTOCOL,
ENV_MGE_HOME,
ENV_XDG_CACHE_HOME,
HUBCONF,
HUBDEPENDENCY,
)
from .exceptions import InvalidProtocol
from .fetcher import GitHTTPSFetcher, GitSSHFetcher
from .tools import cd, check_module_exists, load_module
# Module-level logger, named after this module.
logger = get_logger(__name__)

# Maps each supported protocol name to the fetcher class whose ``fetch`` method
# is used by ``_get_repo`` to obtain a repository.
PROTOCOLS = {
    "HTTPS": GitHTTPSFetcher,
    "SSH": GitSSHFetcher,
}
def _get_megengine_home() -> str:
    """Return the MegEngine home directory with ``~`` expanded.

    The value of the environment variable named by ``ENV_MGE_HOME`` is used
    when it is set. Otherwise the result is ``<cache root>/megengine``, where
    the cache root is the environment variable named by ``ENV_XDG_CACHE_HOME``
    or, failing that, ``DEFAULT_CACHE_DIR``.
    """
    xdg_cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(xdg_cache_root, "megengine")
    configured_home = os.getenv(ENV_MGE_HOME, fallback_home)
    return os.path.expanduser(configured_home)
def _get_repo(
    git_host: str,
    repo_info: str,
    use_cache: bool = False,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    """Fetch a repository into the hub cache and return its directory.

    :param git_host: forwarded to the fetcher's ``fetch``.
    :param repo_info: forwarded to the fetcher's ``fetch``.
    :param use_cache: forwarded to the fetcher's ``fetch``.
    :param commit: forwarded to the fetcher's ``fetch``.
    :param protocol: key into ``PROTOCOLS`` selecting the fetcher class.
    :raises InvalidProtocol: if ``protocol`` is not a key of ``PROTOCOLS``.
    :return: the hub cache directory joined with the directory name returned
        by the fetcher.
    """
    if protocol not in PROTOCOLS:
        supported = ", ".join(PROTOCOLS.keys())
        raise InvalidProtocol(
            "Invalid protocol, the value should be one of {}.".format(supported)
        )

    hub_root = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    fetcher_cls = PROTOCOLS[protocol]
    # The fetch runs with the hub cache directory as the working directory.
    with cd(hub_root):
        fetched_dir = fetcher_cls.fetch(git_host, repo_info, use_cache, commit)
    return os.path.join(hub_root, fetched_dir)
def _check_dependencies(module: types.ModuleType) -> None:
    """Verify that every dependency declared by ``module`` can be found.

    Dependencies are read from the attribute named by ``HUBDEPENDENCY``; a
    missing or empty attribute means there is nothing to check.

    :param module: the loaded hub configuration module.
    :raises RuntimeError: listing the declared dependencies for which
        ``check_module_exists`` returned a false value.
    """
    declared = getattr(module, HUBDEPENDENCY, None)
    if not declared:
        return

    absent = [name for name in declared if not check_module_exists(name)]
    if absent:
        raise RuntimeError("Missing dependencies: {}".format(", ".join(absent)))
def _init_hub(
    repo_info: str,
    git_host: str,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
):
    """Fetch a hub repository and load its hub configuration module.

    The hub cache directory (``<megengine home>/hub``) is created if needed,
    the repository is obtained through ``_get_repo``, and the file named by
    ``HUBCONF`` inside it is loaded with ``load_module``.

    :param repo_info: forwarded to ``_get_repo``.
    :param git_host: forwarded to ``_get_repo``.
    :param use_cache: forwarded to ``_get_repo``.
    :param commit: forwarded to ``_get_repo``.
    :param protocol: forwarded to ``_get_repo``; must be a key of ``PROTOCOLS``.
    :return: the module object returned by ``load_module``.
    """
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    os.makedirs(cache_dir, exist_ok=True)
    absolute_repo_dir = _get_repo(
        git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
    )
    # The repo directory goes first on sys.path so imports made while loading
    # the hub configuration resolve against the repository.
    sys.path.insert(0, absolute_repo_dir)
    try:
        hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
    finally:
        # Always undo the sys.path change, even when loading fails; otherwise
        # a broken hub configuration leaves the repository importable for the
        # rest of the process.
        sys.path.remove(absolute_repo_dir)
    return hubmodule
# Public entry point that forwards all positional and keyword arguments to
# ``_init_hub`` and returns its result. ``functools.wraps`` copies
# ``_init_hub``'s metadata (name, docstring, signature info) onto this wrapper,
# so a docstring written here would be replaced by ``_init_hub``'s.
@functools.wraps(_init_hub)
def import_module(*args, **kwargs):
    return _init_hub(*args, **kwargs)
def list(
    repo_info: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> List[str]:
    """List the callable entry points exposed by a hub repository.

    :param repo_info: forwarded to ``_init_hub``.
    :param git_host: forwarded to ``_init_hub``.
    :param use_cache: forwarded to ``_init_hub``.
    :param commit: forwarded to ``_init_hub``.
    :param protocol: forwarded to ``_init_hub``.
    :return: names from ``dir()`` of the hub configuration module that do not
        start with ``"__"`` and whose values are callable.
    """
    hub_module = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    entry_names = []
    for attr_name in dir(hub_module):
        if attr_name.startswith("__"):
            continue
        if callable(getattr(hub_module, attr_name)):
            entry_names.append(attr_name)
    return entry_names
def load(
    repo_info: str,
    entry: str,
    *args,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
    **kwargs
) -> Any:
    """Call an entry point of a hub repository and return its result.

    :param repo_info: forwarded to ``_init_hub``.
    :param entry: name of a callable attribute of the hub configuration module.
    :param args: positional arguments passed to the entry point.
    :param git_host: forwarded to ``_init_hub``.
    :param use_cache: forwarded to ``_init_hub``.
    :param commit: forwarded to ``_init_hub``.
    :param protocol: forwarded to ``_init_hub``.
    :param kwargs: keyword arguments passed to the entry point.
    :raises RuntimeError: if ``entry`` is missing or not callable, or if
        ``_check_dependencies`` reports missing dependencies.
    :return: whatever the entry point returns.
    """
    hub_module = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    entry_fn = getattr(hub_module, entry, None)
    if not callable(entry_fn):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))

    # Dependencies are validated only after the entry point is known to exist.
    _check_dependencies(hub_module)
    return entry_fn(*args, **kwargs)
def help(
    repo_info: str,
    entry: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    """Return the docstring of an entry point of a hub repository.

    :param repo_info: forwarded to ``_init_hub``.
    :param entry: name of a callable attribute of the hub configuration module.
    :param git_host: forwarded to ``_init_hub``.
    :param use_cache: forwarded to ``_init_hub``.
    :param commit: forwarded to ``_init_hub``.
    :param protocol: forwarded to ``_init_hub``.
    :raises RuntimeError: if ``entry`` is missing or not callable.
    :return: the ``__doc__`` attribute of the entry point.
    """
    hub_module = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    entry_fn = getattr(hub_module, entry, None)
    if not callable(entry_fn):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
    return entry_fn.__doc__
def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
    """Download a serialized object (unless already cached) and load it.

    The local file name is the first six hex digits of the SHA-256 of ``url``,
    an underscore, and the base name of the URL path. The download is skipped
    when that file already exists in ``model_dir``.

    :param url: location of the serialized object.
    :param model_dir: directory used as the download cache; defaults to
        ``<megengine home>/serialized``. It is created if it does not exist.
    :return: the object produced by loading the cached file.
    """
    target_dir = model_dir
    if target_dir is None:
        target_dir = os.path.join(_get_megengine_home(), "serialized")
    os.makedirs(target_dir, exist_ok=True)

    # Prefixing with a digest of the full URL keeps files with the same base
    # name but different URLs apart in the cache.
    url_basename = os.path.basename(urlparse(url).path)
    url_tag = hashlib.sha256(url.encode()).hexdigest()[:6]
    cached_file = os.path.join(target_dir, url_tag + "_" + url_basename)
    logger.info(
        "load_serialized_obj_from_url: download to or using cached %s", cached_file
    )

    if not os.path.exists(cached_file):
        if is_distributed():
            logger.warning(
                "Downloading serialized object in DISTRIBUTED mode\n"
                " File may be downloaded multiple times. We recommend\n"
                " users to download in single process first."
            )
        download_from_url(url, cached_file)

    # NOTE(review): the object is deserialized from downloaded content; if the
    # loader is pickle-based this should only be used with trusted URLs -- confirm.
    return _mge_load_serialized(cached_file)
class pretrained:
    """Decorator that adds an optional ``pretrained`` flag to a model factory.

    The decorated function gains a leading ``pretrained=False`` parameter. All
    other keyword arguments are forwarded to the original function. When
    ``pretrained`` is true, weights are loaded from ``url`` with
    ``load_serialized_obj_from_url`` and applied through the model's
    ``load_state_dict`` method.

    :param url: location of the serialized weights.
    """

    def __init__(self, url):
        self.url = url

    def __call__(self, func):
        """Wrap ``func`` and return the wrapper.

        :param func: factory called with the forwarded keyword arguments.
        :return: a function ``(pretrained=False, **kwargs)`` returning the
            model built by ``func``.
        """

        @functools.wraps(func)
        def pretrained_model_func(pretrained=False, **kwargs):
            # The parameter intentionally shadows the class name inside the
            # wrapper; it is the public flag callers pass.
            model = func(**kwargs)
            if pretrained:
                weights = load_serialized_obj_from_url(self.url)
                model.load_state_dict(weights)
            return model

        return pretrained_model_func
# Names exported by ``from <this module> import *``; the underscore-prefixed
# helpers defined above are intentionally not listed.
__all__ = [
    "list",
    "load",
    "help",
    "load_serialized_obj_from_url",
    "pretrained",
    "import_module",
]