import pickle
from .device import _valid_device, get_default_device
from .tensor import Tensor
from .utils.max_recursion_limit import max_recursion_limit
def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.DEFAULT_PROTOCOL):
    """Serialize ``obj`` into ``f`` with ``pickle_module``.

    Args:
        obj: the object to serialize.
        f: either a path (``str``) to write to, or a file-like object that
            exposes ``write``.
        pickle_module: module providing ``dump`` (defaults to ``pickle``).
        pickle_protocol: protocol number forwarded to ``pickle_module.dump``.

    Raises:
        AssertionError: if ``f`` is not a ``str`` and has no ``write`` attribute.
    """
    if not isinstance(f, str):
        # Serialize inside max_recursion_limit() so that deeply nested
        # objects can be pickled.
        with max_recursion_limit():
            assert hasattr(f, "write"), "{} does not support write".format(f)
            pickle_module.dump(obj, f, pickle_protocol)
        return

    # A string is treated as a filesystem path: open it in binary mode and
    # delegate to the file-object branch above.
    with open(f, "wb") as out_file:
        save(
            obj,
            out_file,
            pickle_module=pickle_module,
            pickle_protocol=pickle_protocol,
        )
class dmap:
    """Context manager that installs a device-mapping callback on ``Tensor``.

    While the context is active, ``Tensor.dmap_callback`` holds
    ``map_location`` wrapped in ``staticmethod``, so it is not bound as a
    method when looked up through an instance. On exit, the attribute is
    restored to whatever it was before entering. Nested (or re-entered)
    ``dmap`` contexts therefore no longer clobber an outer callback.

    Args:
        map_location: callable to install as ``Tensor.dmap_callback``.
    """

    def __init__(self, map_location):
        self.map_location = map_location
        # Stack of previously installed raw class attributes. One entry is
        # pushed per __enter__, so re-entering the same instance is safe.
        self._saved_callbacks = []

    def __enter__(self):
        # Read the raw class-dict entry rather than using getattr. This
        # captures an existing staticmethod wrapper unchanged, so restoring
        # it does not turn it into a bound method. If the attribute is not
        # defined on Tensor itself, fall back to None, which matches the
        # previous unconditional reset.
        self._saved_callbacks.append(Tensor.__dict__.get("dmap_callback"))
        Tensor.dmap_callback = staticmethod(self.map_location)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the callback that was active before the matching __enter__.
        # Returning None lets any exception propagate.
        previous = self._saved_callbacks.pop() if self._saved_callbacks else None
        Tensor.dmap_callback = previous
def _get_callable_map_location(map_location):
if map_location is None:
def callable_map_location(state):
return state
elif isinstance(map_location, str):
def callable_map_location(state):
return map_location
elif isinstance(map_location, dict):
for key, value in map_location.items():
assert _valid_device(key), "Invalid locator_map key value {}".format(key)
assert _valid_device(value), "Invalid locator_map key value {}".format(
value
)
def callable_map_location(state):
if state[:4] in map_location.keys():
state = map_location[state[:4]]
return state
else:
assert callable(map_location), "map_location should be str, dict or function"
callable_map_location = map_location
return callable_map_location
def load(f, map_location=None, pickle_module=pickle):
    """Load an object previously written by ``save``.

    Warning:
        Unpickling can execute arbitrary code. Only load files from trusted
        sources.

    Args:
        f: either a path (``str``) to read from, or a file-like object that
            is passed to ``pickle_module.load``.
        map_location: ``None``, a ``str``, a ``dict`` or a callable.
            It is normalized by ``_get_callable_map_location`` and installed
            via ``dmap`` while unpickling.
        pickle_module: module providing ``load`` (defaults to ``pickle``).

    Returns:
        The object returned by ``pickle_module.load``.
    """
    if isinstance(f, str):
        # A string is treated as a filesystem path: open it in binary mode
        # and delegate to the file-object branch below.
        with open(f, "rb") as fin:
            return load(fin, map_location=map_location, pickle_module=pickle_module)
    map_location = _get_callable_map_location(map_location)
    # The mapping callback is only installed for the duration of the load.
    with dmap(map_location):
        return pickle_module.load(f)