megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from megengine.data.transform import *

# Shape of every synthetic input image (H, W, C).
data_shape = (100, 100, 3)
# Shape of every synthetic label array.
label_shape = (4,)
# Expected image shape after ToMode(mode="CHW"): channel axis moved to the front.
ToMode_target_shape = (data_shape[2],) + data_shape[:2]
CenterCrop_size = (90, 70)
CenterCrop_target_shape = CenterCrop_size + (3,)
RandomResizedCrop_size = (50, 50)
RandomResizedCrop_target_shape = RandomResizedCrop_size + (3,)


def generate_data(batch_size=4):
    """Build a synthetic batch of ``(image, label)`` samples for transform tests.

    Args:
        batch_size: number of samples in the returned batch. It is an explicit
            parameter, independent of ``label_shape``, so changing the label
            shape cannot silently change (or empty) the batch.

    Returns:
        A list of ``batch_size`` tuples ``(image, label)``. ``image`` is a
        ``uint8`` array of shape ``data_shape``, and ``label`` is an integer
        array of shape ``label_shape`` with values in ``[0, 10)``.
    """
    return [
        (
            # randint's upper bound is exclusive, so 256 covers the full uint8
            # range; truncating rand() * 255 could never produce 255.
            np.random.randint(0, 256, size=data_shape, dtype=np.uint8),
            np.random.randint(10, size=label_shape),
        )
        for _ in range(batch_size)
    ]


def test_ToMode():
    """Check that ToMode(mode="CHW") yields channel-first images and keeps label shapes."""
    t = ToMode(mode="CHW")
    batch = generate_data()
    aug_data = t.apply_batch(batch)
    aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
    # Size the expectation from the generated batch rather than a magic 4.
    target_shape = [(ToMode_target_shape, label_shape)] * len(batch)
    assert aug_data_shape == target_shape, "aug {}, target {}".format(
        aug_data_shape, target_shape
    )


def test_CenterCrop():
    """Check that CenterCrop produces images of ``CenterCrop_size`` and keeps label shapes."""
    t = CenterCrop(output_size=CenterCrop_size)
    batch = generate_data()
    aug_data = t.apply_batch(batch)
    aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
    # Size the expectation from the generated batch rather than a magic 4.
    target_shape = [(CenterCrop_target_shape, label_shape)] * len(batch)
    assert aug_data_shape == target_shape, "aug {}, target {}".format(
        aug_data_shape, target_shape
    )


def test_ColorJitter():
    """Check that ColorJitter with default arguments preserves image and label shapes."""
    t = ColorJitter()
    batch = generate_data()
    aug_data = t.apply_batch(batch)
    aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
    # Size the expectation from the generated batch rather than a magic 4.
    target_shape = [(data_shape, label_shape)] * len(batch)
    assert aug_data_shape == target_shape, "aug {}, target {}".format(
        aug_data_shape, target_shape
    )


def test_RandomHorizontalFlip():
    """Check that RandomHorizontalFlip(prob=1) preserves image and label shapes."""
    t = RandomHorizontalFlip(prob=1)
    batch = generate_data()
    aug_data = t.apply_batch(batch)
    aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
    # Size the expectation from the generated batch rather than a magic 4.
    target_shape = [(data_shape, label_shape)] * len(batch)
    assert aug_data_shape == target_shape, "aug {}, target {}".format(
        aug_data_shape, target_shape
    )


def test_RandomVerticalFlip():
    """Check that RandomVerticalFlip(prob=1) preserves image and label shapes."""
    t = RandomVerticalFlip(prob=1)
    batch = generate_data()
    aug_data = t.apply_batch(batch)
    aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
    # Size the expectation from the generated batch rather than a magic 4.
    target_shape = [(data_shape, label_shape)] * len(batch)
    assert aug_data_shape == target_shape, "aug {}, target {}".format(
        aug_data_shape, target_shape
    )


def test_RandomResizedCrop():
    """Check that RandomResizedCrop yields ``RandomResizedCrop_size`` images and keeps label shapes."""
    t = RandomResizedCrop(output_size=RandomResizedCrop_size)
    batch = generate_data()
    aug_data = t.apply_batch(batch)
    aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
    # Size the expectation from the generated batch rather than a magic 4.
    target_shape = [(RandomResizedCrop_target_shape, label_shape)] * len(batch)
    assert aug_data_shape == target_shape, "aug {}, target {}".format(
        aug_data_shape, target_shape
    )


def test_Normalize():
    """Check that Normalize with default arguments preserves image and label shapes."""
    t = Normalize()
    batch = generate_data()
    aug_data = t.apply_batch(batch)
    aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
    # Size the expectation from the generated batch rather than a magic 4.
    target_shape = [(data_shape, label_shape)] * len(batch)
    assert aug_data_shape == target_shape, "aug {}, target {}".format(
        aug_data_shape, target_shape
    )


def test_RandomCrop():
    """Check that RandomCrop with padding yields images of the requested output size.

    The requested size (150, 120) exceeds the 100x100 inputs, so this also
    exercises the transform's handling of crops larger than the source image.
    """
    # Single source of truth for the crop size used by both the transform
    # and the expected shape.
    output_size = (150, 120)
    t = RandomCrop(output_size, padding_size=10, padding_value=[1, 2, 3])
    batch = generate_data()
    aug_data = t.apply_batch(batch)
    aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
    # Size the expectation from the generated batch rather than a magic 4.
    target_shape = [(output_size + (3,), label_shape)] * len(batch)
    assert aug_data_shape == target_shape, "aug {}, target {}".format(
        aug_data_shape, target_shape
    )


def test_Compose():
    """Check that a composed CenterCrop -> HorizontalFlip -> ToMode("CHW") pipeline
    yields channel-first crops of ``CenterCrop_size`` and keeps label shapes."""
    t = Compose(
        [
            CenterCrop(output_size=CenterCrop_size),
            RandomHorizontalFlip(prob=1),
            ToMode(mode="CHW"),
        ]
    )
    batch = generate_data()
    aug_data = t.apply_batch(batch)
    aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
    # Derive the expected CHW shape from CenterCrop_size instead of repeating
    # (3, 90, 70), and size the list from the batch instead of a magic 4.
    target_shape = [((3,) + CenterCrop_size, label_shape)] * len(batch)
    assert aug_data_shape == target_shape, "aug {}, target {}".format(
        aug_data_shape, target_shape
    )