megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from copy import deepcopy
from typing import Union

from ..core.tensor.dtype import QuantDtypeMeta
from ..quantization.fake_quant import QParamsModuleMixin, _FakeQuantize
from ..quantization.utils import QParams, QuantMode, fake_quant_tensor


class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
    r"""A module to do quant and dequant according to :attr:`~.FakeQuantize.qparams`.

    The quantization parameters start out as ``None`` and are installed with
    :meth:`set_qparams`; :meth:`fake_quant_forward` falls back to them when the
    caller does not pass explicit ones.

    Args:
        dtype: a dtype name or a :class:`QuantDtypeMeta`; forwarded unchanged to
            the base class constructor.
        enable: forwarded unchanged to the base class constructor. Default: True
        kwargs: extra keyword arguments forwarded to the base class constructor.
    """

    def __init__(
        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
    ):
        super().__init__(dtype, enable, **kwargs)
        # No quantization parameters are available until set_qparams() is called.
        self.qparams = None

    def _check_dtype_meta(self, qparams: QParams):
        r"""Raise if ``qparams.dtype_meta`` is not the very object ``self.dtype``.

        An explicit ``raise`` is used instead of an ``assert`` statement so the
        check is still performed when Python runs with ``-O``.

        Raises:
            AssertionError: if the identity check fails.
        """
        # NOTE(review): ``self.dtype`` is presumably set by the base class
        # constructor -- it is not assigned anywhere in this class.
        if qparams.dtype_meta is not self.dtype:
            raise AssertionError(
                "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
                    qparams.dtype_meta, self.dtype
                )
            )

    def fake_quant_forward(self, inp, qparams: QParams = None):
        r"""Apply :func:`fake_quant_tensor` to ``inp``.

        Args:
            inp: the value handed to :func:`fake_quant_tensor`.
            qparams: parameters to use; when ``None`` the result of
                :meth:`get_qparams` is used instead.

        Returns:
            Whatever :func:`fake_quant_tensor` returns for ``inp`` and the
            selected parameters.

        Raises:
            AssertionError: if no parameters are available, or if their
                ``dtype_meta`` is not ``self.dtype``.
        """
        if qparams is None:
            qparams = self.get_qparams()
        if qparams is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise AssertionError(
                "qparams is not given and has not been initialized by set_qparams"
            )
        self._check_dtype_meta(qparams)
        return fake_quant_tensor(inp, qparams)

    def get_qparams(self):
        r"""Return the stored :attr:`~.FakeQuantize.qparams` (``None`` if never set)."""
        return self.qparams

    def set_qparams(self, qparams: QParams):
        r"""Initialize :attr:`~.FakeQuantize.qparams`.

        A new :class:`QParams` is built from the fields of ``qparams`` and stored
        on the module. If ``qparams.dtype_meta`` is ``None`` it is filled in with
        ``self.dtype`` on the passed-in object itself (kept for backward
        compatibility).

        Args:
            qparams: used to set initial ``scale`` and ``zero_point``.

        Raises:
            AssertionError: if ``qparams.scale`` is ``None``, or if
                ``qparams.dtype_meta`` is set but is not ``self.dtype``.
        """
        if qparams.scale is None:
            raise AssertionError("Can not get an initialized scale")
        if qparams.dtype_meta is None:
            qparams.dtype_meta = self.dtype
        else:
            self._check_dtype_meta(qparams)

        self.qparams = QParams(
            qparams.mode, qparams.dtype_meta, qparams.scale, qparams.zero_point
        )