from typing import Union
import numpy as np
from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply, set_py_tensor_type
from .core._trace_option import use_symbolic_shape
from .core._wrap import as_device
from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device
from .logger import get_logger
from .utils.deprecation import deprecated
# Module-level logger; Tensor.__setstate__ uses it to warn when a pickle
# carries the legacy 'qdict' state key.
logger = get_logger(__name__)
class Tensor(_Tensor, ArrayMethodMixin):
    """Python-level tensor built on the imperative runtime's ``_Tensor``.

    Array-style methods come from ``ArrayMethodMixin``. This class adds
    construction from Python/NumPy data, naming, lazily created quantization
    parameters, device transfer (``to``) and pickling support.
    """

    # Class-level defaults; instances shadow them by plain assignment.
    grad = None  # not read anywhere in this class
    # Optional callable applied to device *strings* in ``__new__`` before they
    # are converted to a ``CompNode``.
    dmap_callback = None
    # Quantization parameters; created on first access of ``qparams``.
    _qparams = None
    # Name as last assigned by the user (empty string when unset).
    _custom_name = ""
    # Name including ``_prefix + "."`` when a prefix is set (see ``name``).
    _name = None
    _short_name = None
    _prefix = None

    def __new__(
        cls,
        data: Union["Tensor", np.ndarray, list, int, float] = None,
        dtype: np.dtype = None,
        device: str = None,
        is_const: bool = False,
        no_cache: bool = False,
        name: str = None,
    ):
        """Create a tensor from ``data`` on the requested device.

        Args:
            data: source values. ``None`` is treated as an empty list. If
                ``data`` is already a ``_Tensor``, it is handed to the runtime
                on its own: ``dtype``, ``device``, ``is_const``, ``no_cache``
                and ``name`` are NOT forwarded in that case.
            dtype: forwarded unchanged to the runtime constructor.
            device: ``None`` selects ``get_default_device()``. A string is
                converted to a ``CompNode``, after passing through
                ``dmap_callback`` when one is set. A ``CompNode`` is used
                as-is. Any other object must expose a ``_cn`` attribute.
            is_const: forwarded unchanged to the runtime constructor.
            no_cache: forwarded unchanged to the runtime constructor.
            name: forwarded to the runtime constructor; ``__init__`` applies
                it again via ``_set_name``.

        Returns:
            A new instance of ``cls``.
        """
        if data is None:
            data = []

        # Resolve the target compute node. This happens before the _Tensor
        # check below, exactly as before, so device errors surface for every
        # kind of ``data``.
        if device is None:
            cn = get_default_device()
        elif isinstance(device, str):
            if cls.dmap_callback is not None:
                cn = CompNode(cls.dmap_callback(device))
            else:
                cn = CompNode(device)
        elif isinstance(device, CompNode):
            cn = device
        else:
            # Device-like wrapper objects carry their CompNode in ``_cn``.
            cn = device._cn

        if isinstance(data, _Tensor):
            return _Tensor.__new__(cls, data)

        if isinstance(data, np.ndarray) and 0 in data.strides:
            # Attempt to normalise arrays whose strides contain 0 by squeezing
            # and reshaping back to the original shape.
            # NOTE(review): squeeze() only removes length-1 axes, and reshape()
            # returns a view whenever it can. A zero stride on an axis of
            # length > 1 (e.g. ``np.broadcast_to`` output) is therefore NOT
            # removed by this step. Confirm that the runtime accepts such
            # arrays.
            data = data.squeeze().reshape(data.shape)
        return _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name)

    def __init__(
        self,
        data: Union["Tensor", np.ndarray, list, int, float] = None,
        dtype: np.dtype = None,
        device: str = None,
        is_const: bool = False,
        no_cache: bool = False,
        name: str = None,
    ):
        """Initialise the Python-side naming state.

        ``data``, ``dtype``, ``device``, ``is_const`` and ``no_cache`` are
        consumed by ``__new__``. They are accepted here only because Python
        passes the constructor arguments to both methods. ``data`` defaults
        to ``None`` to mirror ``__new__``; without that default, ``Tensor()``
        would raise ``TypeError`` even though ``__new__`` accepts it.

        Args:
            name: optional tensor name. When given, it is also pushed to the
                runtime through ``_set_name``; otherwise the name is ``""``.
        """
        if name is None:
            name = ""
        else:
            self._set_name(name)
        self._custom_name = name
        self._name = name
        self._short_name = name
        self._prefix = None

    @property
    def shape(self) -> Union[tuple, "Tensor"]:
        """Shape of the tensor.

        Returns the runtime's tuple shape. The exception is when symbolic
        shapes are enabled (``use_symbolic_shape()``) and the shape is not
        ``()``: in that case the shape is returned as the tensor produced by
        the ``GetVarShape`` op.
        """
        shape = super().shape
        if shape == () or not use_symbolic_shape():
            return shape
        return apply(GetVarShape(), self)[0]

    @property
    def _tuple_shape(self):
        """Runtime shape, never converted to a symbolic tensor."""
        return super().shape

    @property
    def device(self) -> CompNode:
        """Compute node the tensor lives on (delegates to the runtime)."""
        return super().device

    @property
    def dtype(self) -> np.dtype:
        """Element type of the tensor (delegates to the runtime)."""
        return super().dtype

    @property
    def qparams(self):
        """Quantization parameters, created with ``create_qparams()`` on first access."""
        # Imported locally, presumably to avoid an import cycle with the
        # quantization package. TODO confirm.
        from .quantization.utils import create_qparams

        if self._qparams is None:
            self._qparams = create_qparams()
        return self._qparams

    def numpy(self) -> np.ndarray:
        """Return the tensor's values as a NumPy array (delegates to the runtime)."""
        return super().numpy()

    def detach(self):
        """Return the runtime's detached tensor (delegates to ``_Tensor.detach``)."""
        return super().detach()

    def _reset(self, other):
        """Reset this tensor from ``other`` through the runtime's ``_reset``.

        A value that is not already a ``_Tensor`` is first wrapped in a
        ``Tensor`` that uses this tensor's dtype and device.
        """
        if not isinstance(other, _Tensor):
            other = Tensor(other, dtype=self.dtype, device=self.device)
        super()._reset(other)

    def __repr__(self):
        """Return ``ClassName(values[, dtype=...], device=...)``.

        Values are printed with 4-digit precision and without scientific
        notation. ``dtype`` is shown only when it is not float32.
        """
        piece = "{}(".format(self.__class__.__name__)
        with np.printoptions(precision=4, suppress=True):
            piece += "{}".format(str(self.numpy()))
        if self.dtype != np.float32:
            piece += ", dtype={}".format(np.dtype(self.dtype).name)
        piece += ", device={}".format(self.device) + ")"
        return piece

    @property
    def name(self):
        """User-assigned name of this tensor (empty string when unset)."""
        return self._custom_name

    @name.setter
    def name(self, name):
        self._custom_name = name
        # Full name is "<prefix>.<name>" when a prefix is set, else just name.
        self._name = self._prefix + "." + name if self._prefix else name
        self._set_name(self._name)

    @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
    def set_value(self, value):
        """Deprecated: reset this tensor from ``value`` (see ``_reset``)."""
        self._reset(value)

    @deprecated(version="1.0", reason="use ``*= 0`` instead")
    def reset_zero(self):
        """Deprecated: multiply this tensor by zero in place."""
        # NOTE(review): this affects the caller's tensor only if ``__imul__``
        # mutates in place rather than just rebinding the local ``self``.
        # Confirm in ArrayMethodMixin.
        self *= 0

    def to(self, device):
        """Return a copy of this tensor placed on ``device``.

        Args:
            device: device name, or any device object accepted by
                ``as_device``.

        Raises:
            ValueError: if ``device`` is a string rejected by
                ``_valid_device``.
        """
        if isinstance(device, str) and not _valid_device(device):
            raise ValueError(
                "invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format(
                    device
                )
            )
        cn = as_device(device).to_c()
        return apply(Copy(comp_node=cn), self)[0]

    @property
    def requires_grad(self):
        """Reserved attribute: reading, setting or deleting it raises AttributeError."""
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.setter
    def requires_grad(self, value):
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.deleter
    def requires_grad(self):
        raise AttributeError("requires_grad is reserved for future use")

    def __hash__(self):
        """Hash by object identity, not by value."""
        return id(self)

    def __getnewargs__(self):
        """Arguments that pickle passes to ``__new__`` on load.

        These are the host values, the dtype and the device's
        ``logical_name``.
        """
        return (self.numpy(), self.dtype, self.device.logical_name)

    def __getstate__(self):
        """Extra pickled state: only the quantization params, and only when set."""
        state = {}
        if self._qparams is not None:
            state["qparams"] = self._qparams
        return state

    def __setstate__(self, state):
        """Restore pickled state.

        Also accepts two legacy formats:

        * states that carry ``data``/``device``/``dtype``, in which case the
          tensor is rebuilt via ``_reset``;
        * the ``qdict`` key in place of ``qparams``, which logs a warning.
        """
        if "data" in state:
            data = state.pop("data")
            device = state.pop("device")
            dtype = state.pop("dtype")
            self._reset(Tensor(data, dtype=dtype, device=device))

        if "qdict" in state:
            qparams = state.pop("qdict")
            logger.warning(
                "Tensor's 'qdict' state is deprecated. Use 'qparams' instead"
            )
        elif "qparams" in state:
            qparams = state.pop("qparams")
        else:
            qparams = None
        self._qparams = qparams
# Hand the Python ``Tensor`` class to the imperative runtime (core2).
# Presumably the runtime then uses it as the Python type for the tensors it
# returns; verify against core2.
set_py_tensor_type(Tensor)
# Lowercase alias: ``tensor`` and ``Tensor`` are the same class object.
tensor = Tensor
class Parameter(Tensor):