# megenginelite-sys 1.8.2
#
# A safe megenginelite wrapper in Rust
# Documentation
import io
import pickle

import numpy as np

import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._trace_option import set_symbolic_shape
from megengine.jit import trace
from megengine.traced_module import trace_module

# Module-level switch, executed at import time before any tracing below runs.
# NOTE(review): presumably makes ``Tensor.shape`` symbolic while tracing so the
# shape-dependent arithmetic in PreProcess.forward is captured in the graph --
# confirm against megengine.core._trace_option.
set_symbolic_shape(True)


class Main(M.Module):
    """Pass-through module whose ``forward`` returns its input untouched."""

    def forward(self, x):
        """Return ``x`` exactly as received."""
        return x


class PreProcess(M.Module):
    """ROI-driven warp front end built on ``F.warp_perspective``.

    For every ROI a 3x3 matrix of the form::

        [[scale, 0,     xmin],
         [0,     scale, ymin],
         [0,     0,     1   ]]

    is assembled and handed to ``F.warp_perspective``. The warped result is
    transposed with ``(0, 3, 1, 2)`` and cast to ``float32``.
    """

    def __init__(self):
        super().__init__()
        # One-element seeds; ``forward`` broadcasts them to the batch size.
        self.I = F.ones((1,))
        self.M = F.zeros((1,))

    def forward(self, data, idx, roi):
        """Warp ``data`` according to ``roi`` and return a float32 tensor.

        Args:
            data: input whose ``shape`` unpacks into four values, read here
                as ``(N, H, W, C)`` and warped with ``format="NHWC"``.
            idx: forwarded to ``F.warp_perspective`` as ``mat_idx``.
            roi: indexed as ``roi[:, corner, axis]``; corner 0 supplies
                ``(xmin, ymin)`` and corner 1 supplies ``(xmax, ymax)``.

        Returns:
            The warped data after ``transpose(0, 3, 1, 2)`` and a cast to
            ``np.float32``.
        """
        batch, height, width, _channels = data.shape

        x_hi = roi[:, 1, 0]
        x_lo = roi[:, 0, 0]
        y_hi = roi[:, 1, 1]
        y_lo = roi[:, 0, 1]

        # One isotropic scale per ROI: the larger of the two axis ratios.
        x_ratio = (x_hi - x_lo) / width
        y_ratio = (y_hi - y_lo) / height
        scale = F.maximum(x_ratio, y_ratio)

        unit = F.broadcast_to(self.I, (batch,))
        mat = F.broadcast_to(self.M, (batch, 3, 3))
        # Fill only the non-zero entries; everything else stays at zero.
        mat[:, 0, 0] = scale
        mat[:, 0, 2] = x_lo
        mat[:, 1, 1] = scale
        mat[:, 1, 2] = y_lo
        mat[:, 2, 2] = unit

        warped = F.warp_perspective(
            data,
            mat,
            (height, width),
            mat_idx=idx,
            border_mode="CONSTANT",
            format="NHWC",
        )
        channels_first = warped.transpose(0, 3, 1, 2)
        return channels_first.astype(np.float32)


class Net(M.Module):
    """Run ``PreProcess`` and feed its output to a supplied traced module."""

    def __init__(self, traced_module):
        """Store ``traced_module`` and create the ``PreProcess`` stage.

        Args:
            traced_module: callable module applied to the pre-processed data.
        """
        super().__init__()
        self.pre_process = PreProcess()
        self.traced_module = traced_module

    def forward(self, data, idx, roi):
        """Return ``traced_module(pre_process(data, idx, roi))``."""
        return self.traced_module(self.pre_process(data, idx, roi))


def test_preprocess():
    """Check that ``Net`` gives the same output at every stage of tracing.

    Stages compared against the eager result:
      1. the ``trace_module`` version of ``Net`` (wrapping a pickled and
         re-loaded traced ``Main``),
      2. the ``trace`` (``capture_as_const=True``) function built from it,
      3. the dumped graph loaded back through ``cgtools.GraphInference``.
    """
    identity = Main()
    data = F.ones((1, 14, 8, 8), dtype=np.uint8)

    # Round-trip the inner traced module through pickle before wrapping it.
    inner_traced = trace_module(identity, data)
    restored = pickle.loads(pickle.dumps(inner_traced))

    net = Net(restored)
    net.eval()
    idx = F.zeros((1,), dtype=np.int32)
    roi = F.ones((1, 2, 2), dtype=np.float32)
    expected = net(data, idx, roi)

    net_traced = trace_module(net, data, idx, roi)
    np.testing.assert_array_equal(net_traced(data, idx, roi), expected)

    jit_fn = trace(net_traced, capture_as_const=True)
    np.testing.assert_array_equal(jit_fn(data, idx, roi), expected)

    # Dump to an in-memory buffer and rewind it so it can be read back.
    buffer = io.BytesIO()
    jit_fn.dump(buffer, arg_names=("data", "idx", "roi"))
    buffer.seek(0)
    infer_cg = cgtools.GraphInference(buffer)

    feed = {"data": data.numpy(), "idx": idx.numpy(), "roi": roi.numpy()}
    outputs = infer_cg.run(inp_dict=feed)
    first_output = list(outputs.values())[0]
    np.testing.assert_allclose(first_output, expected, atol=1e-6)