megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import megengine.module as M
from megengine.quantization import quantize, quantize_qat


def test_repr():
    """Check the ``repr`` of a small network at each quantization stage.

    A single ``Net`` instance is inspected three times:

    1. as a plain float module,
    2. after ``quantize_qat(model)``,
    3. after ``quantize(model)``.

    The conversion calls' return values are ignored, so the test relies on
    both of them rewriting ``model`` in place.
    """

    class Net(M.Module):
        # NOTE: the class name and attribute names below appear verbatim in
        # the expected repr strings, so they must not be renamed.
        def __init__(self):
            super().__init__()
            self.conv_bn = M.ConvBnRelu2d(3, 3, 3)
            self.linear = M.Linear(3, 3)

        def forward(self, x):
            # Identity forward: only the registered submodules matter here.
            return x

    model = Net()

    # Child lines shared by the float and QAT ConvBnRelu2d reprs.
    conv_line = "    (conv): Conv2d(3, 3, kernel_size=(3, 3))\n"
    bn_line = (
        "    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
    )
    # Observer / fake-quant children expected on every QAT module.
    qat_children = (
        "    (act_observer): ExponentialMovingAverageObserver()\n"
        "    (act_fake_quant): FakeQuantize()\n"
        "    (weight_observer): MinMaxObserver()\n"
        "    (weight_fake_quant): FakeQuantize()\n"
    )

    float_expected = (
        "Net(\n"
        "  (conv_bn): ConvBnRelu2d(\n"
        + conv_line
        + bn_line
        + "  )\n"
        "  (linear): Linear(in_features=3, out_features=3, bias=True)\n"
        ")"
    )
    qat_expected = (
        "Net(\n"
        "  (conv_bn): QAT.ConvBnRelu2d(\n"
        + conv_line
        + bn_line
        + qat_children
        + "  )\n"
        "  (linear): QAT.Linear(\n"
        "    in_features=3, out_features=3, bias=True\n"
        + qat_children
        + "  )\n"
        ")"
    )
    quantized_expected = (
        "Net(\n"
        "  (conv_bn): Quantized.ConvBnRelu2d(3, 3, kernel_size=(3, 3))\n"
        "  (linear): Quantized.Linear()\n"
        ")"
    )

    # Each stage optionally converts the model, then checks its repr.
    stages = (
        (None, float_expected),
        (quantize_qat, qat_expected),
        (quantize, quantized_expected),
    )
    for convert, expected in stages:
        if convert is not None:
            convert(model)
        assert repr(model) == expected