megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle

from .device import _valid_device, get_default_device
from .tensor import Tensor
from .utils.max_recursion_limit import max_recursion_limit


def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.DEFAULT_PROTOCOL):
    r"""Save an object to disk file.

    Args:
        obj: object to save. Only ``module`` or ``state_dict`` are allowed.
        f: a string of file name or a text file object to which ``obj`` is saved to.
        pickle_module: Default: ``pickle``.
        pickle_protocol: Default: ``pickle.DEFAULT_PROTOCOL``.
    """
    if isinstance(f, str):
        with open(f, "wb") as fout:
            save(
                obj, fout, pickle_module=pickle_module, pickle_protocol=pickle_protocol
            )
        return

    with max_recursion_limit():
        assert hasattr(f, "write"), "{} does not support write".format(f)
        pickle_module.dump(obj, f, pickle_protocol)


class dmap:
    def __init__(self, map_location):
        self.map_location = map_location

    def __enter__(self):
        Tensor.dmap_callback = staticmethod(self.map_location)
        return self

    def __exit__(self, type, value, traceback):
        Tensor.dmap_callback = None


def _get_callable_map_location(map_location):
    if map_location is None:

        def callable_map_location(state):
            return state

    elif isinstance(map_location, str):

        def callable_map_location(state):
            return map_location

    elif isinstance(map_location, dict):
        for key, value in map_location.items():
            # dict key and values can only be "xpux", "cpux", "gpu0", etc.
            assert _valid_device(key), "Invalid locator_map key value {}".format(key)
            assert _valid_device(value), "Invalid locator_map key value {}".format(
                value
            )

        def callable_map_location(state):
            if state[:4] in map_location.keys():
                state = map_location[state[:4]]
            return state

    else:
        assert callable(map_location), "map_location should be str, dict or function"
        callable_map_location = map_location
    return callable_map_location


def load(f, map_location=None, pickle_module=pickle):
    r"""Load an object saved with :func:~.megengine.save` from a file.

    Args:
        f: a string of file name or a text file object from which to load.
        map_location: Default: ``None``.
        pickle_module: Default: ``pickle``.

    Note:
       * ``map_location`` defines device mapping. See examples for usage.
       * If you will call :func:`~.megengine.set_default_device()`, please do it
         before :func:`~.megengine.load()`.

    Examples:
        .. code-block::

           import megengine as mge
           # Load tensors to the same device as defined in model.pkl
           mge.load('model.pkl')
           # Load all tensors to gpu0.
           mge.load('model.pkl', map_location='gpu0')
           # Load all tensors originally on gpu0 to cpu0
           mge.load('model.pkl', map_location={'gpu0':'cpu0'})
           # Load all tensors to cpu0
           mge.load('model.pkl', map_location=lambda dev: 'cpu0')
    """
    if isinstance(f, str):
        with open(f, "rb") as fin:
            return load(fin, map_location=map_location, pickle_module=pickle_module)

    map_location = _get_callable_map_location(map_location)  # callable map_location

    with dmap(map_location) as dm:
        return pickle_module.load(f)