megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
import io
from functools import partial
from itertools import chain
from typing import Callable

import numpy as np

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.module.qat as QM
import megengine.quantization as Q
from megengine import Tensor
from megengine.module.qat.module import QATModule
from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module.utils import get_subattr


class MyConvBnRelu2d(M.ConvBnRelu2d):
    """Float ``ConvBnRelu2d`` subclass that adds no behavior of its own.

    It gives the tests a user-defined module type; it is the key of the custom
    ``mapping`` passed to ``Q.quantize_qat`` in ``build_observered_net``.
    """


class MyQATConvBnRelu2d(QM.ConvBnRelu2d):
    """QAT ``ConvBnRelu2d`` subclass that adds no behavior of its own.

    ``build_observered_net`` maps ``MyConvBnRelu2d`` to this class when it
    converts a network with ``Q.quantize_qat``.
    """


class Myblcok(M.Module):
    """Block with one conv-bn-relu stem feeding two conv-bn branches.

    The two branch outputs are merged by an ``M.Elemwise`` module configured
    with the ``"FUSE_ADD_RELU"`` method.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in the order forward() uses them.
        self.conv0 = MyConvBnRelu2d(3, 3, 3, 1, 1)
        self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0)
        self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0)
        self.add = M.Elemwise("FUSE_ADD_RELU")

    def forward(self, x):
        """Apply ``conv0``, run ``conv1`` and ``conv2`` on its output, and merge them."""
        stem = self.conv0(x)
        branch_a = self.conv1(stem)
        branch_b = self.conv2(stem)
        return self.add(branch_a, branch_b)


class MyModule(M.Module):
    """Network made of two ``Myblcok`` instances applied one after the other."""

    def __init__(self):
        super().__init__()
        self.block0 = Myblcok()
        self.block1 = Myblcok()

    def forward(self, x):
        """Feed ``x`` through ``block0`` and pass the result through ``block1``."""
        hidden = self.block0(x)
        return self.block1(hidden)


class MyMinMaxObserver(Q.MinMaxObserver):
    """``Q.MinMaxObserver`` subclass that adds no behavior of its own.

    ``test_trace_qat`` uses it to trace a network observed by a user-defined
    observer type.
    """


class MyTQT(Q.TQT):
    """``Q.TQT`` subclass that adds no behavior of its own.

    ``test_trace_qat`` uses it to trace a network whose fake-quant modules are
    of a user-defined type.
    """


def get_lsq_config(lsq_cls):
    """Build a ``Q.QConfig`` that only sets fake-quant factories.

    Args:
        lsq_cls: fake-quant class; it is bound through ``partial`` with
            ``dtype="qint8_narrow"`` for weights and ``dtype="qint8"`` for
            activations.

    Returns:
        A ``Q.QConfig`` whose two observer slots are ``None``.
    """
    weight_factory = partial(lsq_cls, dtype="qint8_narrow")
    act_factory = partial(lsq_cls, dtype="qint8")
    return Q.QConfig(
        weight_observer=None,
        act_observer=None,
        weight_fake_quant=weight_factory,
        act_fake_quant=act_factory,
    )


def get_observer_config(observer_cls):
    """Build a ``Q.QConfig`` that only sets observer factories.

    Args:
        observer_cls: observer class; it is bound through ``partial`` with
            ``dtype="qint8_narrow"`` for weights and ``dtype="qint8"`` for
            activations.

    Returns:
        A ``Q.QConfig`` whose two fake-quant slots are ``None``.
    """
    weight_factory = partial(observer_cls, dtype="qint8_narrow")
    act_factory = partial(observer_cls, dtype="qint8")
    return Q.QConfig(
        weight_observer=weight_factory,
        act_observer=act_factory,
        weight_fake_quant=None,
        act_fake_quant=None,
    )


def get_qparams(mod: QATModule):
    """Collect the weight and activation quantization parameters of ``mod``.

    For each of weight and activation, the observer's ``get_qparams()`` result
    is taken first and is then replaced by the fake-quant module's result when
    a fake-quant module is present.

    Args:
        mod: object exposing ``weight_observer``, ``weight_fake_quant``,
            ``act_observer`` and ``act_fake_quant`` attributes, each either
            ``None`` or an object with a ``get_qparams()`` method.

    Returns:
        A ``(weight_qparams, act_qparams)`` tuple; an entry is ``None`` when
        neither of its two sources is set.
    """

    def _pick(observer, fake_quant):
        # Presence is tested with ``is not None`` for both sources, so a
        # fake-quant object that happens to evaluate as falsy is still used.
        qparams = None
        if observer is not None:
            qparams = observer.get_qparams()
        if fake_quant is not None:
            qparams = fake_quant.get_qparams()
        return qparams

    act_qparams = _pick(mod.act_observer, mod.act_fake_quant)
    weight_qparams = _pick(mod.weight_observer, mod.weight_fake_quant)
    return weight_qparams, act_qparams


def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):
    """Assert that two qparams objects hold the same quantization settings.

    Compares ``dtype_meta`` and ``mode`` with ``==``, and the ``scale`` and
    ``zero_point`` values through their ``numpy()`` arrays.

    Args:
        qparmsa: first qparams object.
        qparmsb: second qparams object.

    Raises:
        AssertionError: if any compared field differs, including the case
            where only one of the two objects has a ``zero_point``.
    """
    assert qparmsa.dtype_meta == qparmsb.dtype_meta
    assert qparmsa.mode == qparmsb.mode
    np.testing.assert_equal(qparmsa.scale.numpy(), qparmsb.scale.numpy())

    zp_a, zp_b = qparmsa.zero_point, qparmsb.zero_point
    if zp_a is None or zp_b is None:
        # A zero point present on one side only is a mismatch, whichever
        # side it is on.
        assert zp_a is None and zp_b is None, "zero_point is set on only one side"
    else:
        np.testing.assert_equal(zp_a.numpy(), zp_b.numpy())


def build_observered_net(net: M.Module, observer_cls):
    """Convert ``net`` with ``Q.quantize_qat`` and run one observed forward pass.

    Args:
        net: module handed to ``Q.quantize_qat``.
        observer_cls: observer class forwarded to ``get_observer_config``.

    Returns:
        The module returned by ``Q.quantize_qat``, switched to eval mode, after
        one forward pass on random input with observers enabled; observers are
        disabled again before returning.
    """
    observer_qconfig = get_observer_config(observer_cls)
    custom_mapping = {MyConvBnRelu2d: MyQATConvBnRelu2d}
    observed = Q.quantize_qat(net, qconfig=observer_qconfig, mapping=custom_mapping)

    Q.enable_observer(observed)
    sample = Tensor(np.random.random(size=(5, 3, 32, 32)))
    observed.eval()
    # The output is discarded; the pass runs while observers are enabled.
    observed(sample)
    Q.disable_observer(observed)
    return observed


def build_fakequanted_net(net: QATModule, fakequant_cls):
    """Reset ``net``'s qconfig to a fake-quant-only config and set eval mode.

    Args:
        net: module handed to ``Q.reset_qconfig``.
        fakequant_cls: fake-quant class forwarded to ``get_lsq_config``.

    Returns:
        The module returned by ``Q.reset_qconfig``, switched to eval mode.
    """
    fakequant_qconfig = get_lsq_config(fakequant_cls)
    fakequanted = Q.reset_qconfig(net, fakequant_qconfig)
    fakequanted.eval()
    return fakequanted


def test_trace_qat():
    """Trace QAT networks and compare qparams and flattened-graph node wiring."""

    def _check_qat_module(qat_net: QATModule):
        sample = Tensor(np.random.random(size=(5, 3, 32, 32)))
        traced_net = trace_module(qat_net, sample)

        # Each QAT sub-module must report the same qparams before and after tracing.
        for name, qat_module in qat_net.named_modules():
            if isinstance(qat_module, QATModule):
                traced_counterpart = get_subattr(traced_net, name)
                expected_w, expected_a = get_qparams(qat_module)
                traced_w, traced_a = get_qparams(traced_counterpart)
                if expected_w:
                    check_qparams(expected_w, traced_w)
                if expected_a:
                    check_qparams(expected_a, traced_a)

        # After flattening, conv0 of block0 is still owned by a TracedModule
        # and is the first input of the expression producing its output node.
        flat_graph = traced_net.flatten().graph
        conv0_node = flat_graph.get_node_by_name("MyModule_block0_conv0").as_unique()
        conv0_out_node = flat_graph.get_node_by_name(
            "MyModule_block0_conv0_out"
        ).as_unique()
        assert isinstance(conv0_node.owner, TracedModule)
        assert conv0_out_node.expr.inputs[0] is conv0_node

    # (observer class, fake-quant class or None); each case builds a fresh net.
    cases = (
        (Q.MinMaxObserver, None),
        (MyMinMaxObserver, None),
        (Q.MinMaxObserver, Q.TQT),
        (Q.MinMaxObserver, MyTQT),
    )
    for observer_cls, fakequant_cls in cases:
        qat_net = build_observered_net(MyModule(), observer_cls)
        if fakequant_cls is not None:
            qat_net = build_fakequanted_net(qat_net, fakequant_cls)
        _check_qat_module(qat_net)


def test_load_param():
    """Load a saved state dict into traced modules and compare every tensor."""

    def _check_param(moda: M.Module, modb: M.Module):
        # Every parameter and buffer of moda must equal the attribute found
        # under the same name in modb.
        named_tensors = chain(moda.named_parameters(), moda.named_buffers())
        for key, expected in named_tensors:
            actual = get_subattr(modb, key)
            np.testing.assert_equal(expected.numpy(), actual.numpy())

    def _check_module(build_func: Callable):
        reference = build_func()
        reference.eval()
        saved = io.BytesIO()
        mge.save(reference.state_dict(), saved)

        sample = Tensor(np.random.random(size=(5, 3, 32, 32)))
        # Check the traced module as returned first, then its flattened form;
        # each round traces a freshly built network.
        for flatten in (False, True):
            saved.seek(0)
            traced_net = trace_module(build_func(), sample)
            if flatten:
                traced_net = traced_net.flatten()
            traced_net.load_state_dict(mge.load(saved))
            _check_param(reference, traced_net)

    _check_module(lambda: MyModule())
    _check_module(lambda: build_observered_net(MyModule(), Q.MinMaxObserver))


def test_qualname():
    """Check that every traced node's qualname resolves on both modules."""

    def _check_qualname(net):
        sample = Tensor(np.random.random(size=(5, 3, 32, 32)))
        net.eval()
        traced_net = trace_module(net, sample)
        # Skip the graph's own qualname plus the one character that follows it.
        prefix_len = len(traced_net.graph.qualname) + 1
        for node in traced_net.graph.nodes():
            attr_path = node.qualname[prefix_len:]
            if attr_path.endswith("]"):
                # Drop the last dotted component when the path ends with "]".
                attr_path = attr_path.rsplit(".", 1)[0]
            if attr_path.startswith("["):
                # Nothing but a bracketed part is left: look up the empty path.
                attr_path = ""
            traced_attr = get_subattr(traced_net, attr_path)
            orig_attr = get_subattr(net, attr_path)
            assert traced_attr is not None
            assert orig_attr is not None

    for candidate in (
        MyModule(),
        build_observered_net(MyModule(), Q.MinMaxObserver),
    ):
        _check_qualname(candidate)