megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import contextlib
import os
import tempfile

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.jit import trace
from megengine.traced_module import trace_module


@contextlib.contextmanager
def mkstemp():
    """Yield the path of a fresh temporary file and delete it on exit.

    The file is created with ``tempfile.mkstemp``; its descriptor is closed
    before the path is handed to the ``with`` body, so the caller only ever
    receives a path. The file is removed when the ``with`` block ends,
    whether it finished normally or raised.

    Yields:
        The path of the temporary file.
    """
    descriptor, tmp_path = tempfile.mkstemp()
    try:
        # Only the path is handed out, so release the OS-level handle first.
        os.close(descriptor)
        yield tmp_path
    finally:
        os.unlink(tmp_path)


def minibatch_generator(batch_size):
    """Yield an endless stream of random XOR-style minibatches.

    Every sample is a 2-D point with both coordinates drawn uniformly from
    ``[-1, 1)``. Its label is 1 when the product of the two coordinates is
    negative (the coordinates have opposite signs) and 0 otherwise.

    Args:
        batch_size: number of samples in each yielded minibatch.

    Yields:
        A dict with ``"data"``, a float32 array of shape ``(batch_size, 2)``,
        and ``"label"``, an int32 array of shape ``(batch_size,)``.
    """
    while True:
        # One vectorized draw for the whole batch instead of a Python loop
        # filling the array row by row.
        inp_data = np.random.rand(batch_size, 2) * 2 - 1
        # Labels are derived from the float64 samples, before the float32 cast.
        label = (np.prod(inp_data, axis=1) < 0).astype(np.int32)
        yield {"data": inp_data.astype(np.float32), "label": label}


class XORNet(M.Module):
    """Small fully connected classifier used by the tests in this file.

    Layer order, as applied in ``forward``:
    ``fc0 -> bn0 -> tanh -> fc1 -> bn1 -> tanh -> fc2``.
    """

    def __init__(self):
        # The two size attributes are assigned before the base-class
        # initialiser runs; the submodules are created only afterwards.
        self.mid_dim = 14
        self.num_class = 2
        super().__init__()
        hidden = self.mid_dim
        classes = self.num_class
        # fc0 takes ``num_class`` (2) input features, which is also the
        # output width of fc2.
        self.fc0 = M.Linear(classes, hidden, bias=True)
        self.bn0 = M.BatchNorm1d(hidden)
        self.fc1 = M.Linear(hidden, hidden, bias=True)
        self.bn1 = M.BatchNorm1d(hidden)
        self.fc2 = M.Linear(hidden, classes, bias=True)

    def forward(self, x):
        """Return the output of ``fc2`` for input ``x`` (no softmax applied)."""
        first_hidden = F.tanh(self.bn0(self.fc0(x)))
        second_hidden = F.tanh(self.bn1(self.fc1(first_hidden)))
        return self.fc2(second_hidden)


def test_xornet_trace_dump():
    """Train XORNet briefly through traced functions, then dump the predictor.

    The test runs 101 training steps (steps 0 through 100) with a traced
    training function, evaluates a traced validation function on a separate
    validation minibatch every 50 steps, runs a traced prediction function
    once on a fixed set of points, and finally dumps that prediction trace to
    a temporary file. It passes as long as none of these calls raises.
    """
    net = XORNet()
    opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    gm = GradManager().attach(net.parameters())
    batch_size = 64
    train_dataset = minibatch_generator(batch_size)
    val_dataset = minibatch_generator(batch_size)

    @trace
    def train_fun(data, label):
        with gm:
            net.train()
            pred = net(data)
            loss = F.nn.cross_entropy(pred, label)
            gm.backward(loss)
        return pred, loss

    @trace
    def val_fun(data, label):
        net.eval()
        pred = net(data)
        loss = F.nn.cross_entropy(pred, label)
        return pred, loss

    @trace(symbolic=True, capture_as_const=True)
    def pred_fun(data):
        net.eval()
        pred = net(data)
        pred_normalized = F.softmax(pred)
        return pred_normalized

    # The loss histories are never asserted on; calling ``.numpy()`` is kept
    # so the traced outputs are still read back on every step.
    train_loss = []
    val_loss = []
    for step, minibatch in enumerate(train_dataset):
        if step > 100:
            break
        data = tensor(minibatch["data"])
        label = tensor(minibatch["label"])
        opt.clear_grad()
        _, loss = train_fun(data, label)
        train_loss.append((step, loss.numpy()))
        if step % 50 == 0:
            # Validate on the freshly drawn validation minibatch rather than
            # on the training batch of the current step.
            val_batch = next(val_dataset)
            val_data = tensor(val_batch["data"])
            val_label = tensor(val_batch["label"])
            _, loss = val_fun(val_data, val_label)
            val_loss.append((step, loss.numpy()))
        opt.step()

    test_data = np.array(
        [
            (0.5, 0.5),
            (0.3, 0.7),
            (0.1, 0.9),
            (-0.5, -0.5),
            (-0.3, -0.7),
            (-0.9, -0.1),
            (0.5, -0.5),
            (0.3, -0.7),
            (0.9, -0.1),
            (-0.5, 0.5),
            (-0.3, 0.7),
            (-0.1, 0.9),
        ]
    )

    data = tensor(test_data.astype(np.float32))
    # Execute the traced predictor once before dumping it; the prediction
    # itself is not inspected. (Presumably dump needs a recorded trace.)
    pred_fun(data)

    with mkstemp() as dump_path:
        pred_fun.dump(dump_path, arg_names=["data"], output_names=["label"])


def test_dump_bn_train_mode():
    """Check that dumping a trace with a freshly built BatchNorm2d raises.

    The traced function constructs ``M.BatchNorm2d(10)`` inline and never
    switches it to eval mode (presumably leaving it in training mode, as the
    test name suggests). After one traced call, ``dump`` is expected to raise
    ``RuntimeError``.
    """

    @trace(symbolic=True, capture_as_const=True)
    def bn_train(data):
        pred = M.BatchNorm2d(10)(data).sum()
        return pred

    data = mge.tensor(np.random.random((10, 10, 10, 10)))
    bn_train(data)
    # Dump to a temporary file instead of a fixed path in the working
    # directory, so the test cannot leave a stray artifact behind.
    with mkstemp() as dump_path:
        with pytest.raises(RuntimeError):
            bn_train.dump(dump_path)